{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from comet_ml import Experiment\n",
    "%matplotlib inline\n",
    "import os\n",
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from fastai.vision import *\n",
    "import torch\n",
    "from torchsummary import summary\n",
    "\n",
    "# Reproducibility: seed every RNG the notebook may touch (Python, NumPy, PyTorch CPU/GPU).\n",
    "SEED = 1\n",
    "GPU_ID = 1  # CUDA device used when a GPU is available\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_device(GPU_ID)\n",
    "    torch.cuda.manual_seed(SEED)\n",
    "\n",
    "# stage: index of the student block to train, expected in 0..5.\n",
    "# Stage 0 (the stem conv) is selected explicitly by the freezing code in the model cell,\n",
    "# so pass 0 directly; the older -1 convention also still selects it.\n",
    "hyper_params = {\n",
    "    \"stage\": 0,\n",
    "    \"repeated\": 1,\n",
    "    \"num_classes\": 10,\n",
    "    \"batch_size\": 64,\n",
    "    \"num_epochs\": 100,\n",
    "    \"learning_rate\": 1e-4\n",
    "}\n",
    "\n",
    "\n",
    "def save_torch(name: str, tensor):\n",
    "    \"\"\"Write `tensor` to `name` with np.save, detached from the graph and copied to CPU.\n",
    "\n",
    "    np.save appends '.npy' when `name` does not already end with it.\n",
    "    \"\"\"\n",
    "    # No .clone() needed: np.save serialises the array immediately.\n",
    "    np.save(name, tensor.detach().cpu().numpy())\n",
    "\n",
    "\n",
    "def load_np_torch(name: str):\n",
    "    \"\"\"Load an array written by `save_torch` and return it as a CPU torch tensor.\n",
    "\n",
    "    Accepts the same `name` given to `save_torch`: when `name` itself does not exist\n",
    "    but `name + '.npy'` does, the latter is loaded.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(name) and os.path.exists(f\"{name}.npy\"):\n",
    "        name = f\"{name}.npy\"\n",
    "    return torch.from_numpy(np.load(name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download (if needed) and extract Imagenette; `path` is the local dataset folder.\n",
    "path = untar_data(url=URLs.IMAGENETTE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fastai augmentation set with horizontal flips disabled, passed to the DataBunch via ds_tfms.\n",
    "tfms = get_transforms(do_flip=False)\n",
    "data = ImageDataBunch.from_folder(path, train='train', valid='val', ds_tfms=tfms,\n",
    "                                  bs=hyper_params[\"batch_size\"], size=224).normalize(imagenet_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ImageDataBunch;\n",
       "\n",
       "Train: LabelList (12894 items)\n",
       "x: ImageList\n",
       "Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)\n",
       "y: CategoryList\n",
       "n01440764,n01440764,n01440764,n01440764,n01440764\n",
       "Path: /home/akshay/.fastai/data/imagenette;\n",
       "\n",
       "Valid: LabelList (500 items)\n",
       "x: ImageList\n",
       "Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224),Image (3, 224, 224)\n",
       "y: CategoryList\n",
       "n01440764,n01440764,n01440764,n01440764,n01440764\n",
       "Path: /home/akshay/.fastai/data/imagenette;\n",
       "\n",
       "Test: None"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the DataBunch: train/valid split sizes and a few sample items and labels.\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model on GPU\n",
      "0.0.weight torch.Size([64, 3, 7, 7])\n",
      "True\n",
      "0.2.weight torch.Size([64])\n",
      "True\n",
      "0.2.bias torch.Size([64])\n",
      "True\n",
      "2.0.0.weight torch.Size([64, 64, 3, 3])\n",
      "False\n",
      "2.0.2.weight torch.Size([64])\n",
      "False\n",
      "2.0.2.bias torch.Size([64])\n",
      "False\n",
      "2.1.conv1.0.weight torch.Size([64, 64, 3, 3])\n",
      "False\n",
      "2.1.conv1.2.weight torch.Size([64])\n",
      "False\n",
      "2.1.conv1.2.bias torch.Size([64])\n",
      "False\n",
      "3.0.0.weight torch.Size([128, 64, 3, 3])\n",
      "False\n",
      "3.0.2.weight torch.Size([128])\n",
      "False\n",
      "3.0.2.bias torch.Size([128])\n",
      "False\n",
      "3.1.conv1.0.weight torch.Size([128, 128, 3, 3])\n",
      "False\n",
      "3.1.conv1.2.weight torch.Size([128])\n",
      "False\n",
      "3.1.conv1.2.bias torch.Size([128])\n",
      "False\n",
      "4.0.0.weight torch.Size([256, 128, 3, 3])\n",
      "False\n",
      "4.0.2.weight torch.Size([256])\n",
      "False\n",
      "4.0.2.bias torch.Size([256])\n",
      "False\n",
      "4.1.conv1.0.weight torch.Size([256, 256, 3, 3])\n",
      "False\n",
      "4.1.conv1.2.weight torch.Size([256])\n",
      "False\n",
      "4.1.conv1.2.bias torch.Size([256])\n",
      "False\n",
      "5.0.0.weight torch.Size([512, 256, 3, 3])\n",
      "False\n",
      "5.0.2.weight torch.Size([512])\n",
      "False\n",
      "5.0.2.bias torch.Size([512])\n",
      "False\n",
      "5.1.conv1.0.weight torch.Size([512, 512, 3, 3])\n",
      "False\n",
      "5.1.conv1.2.weight torch.Size([512])\n",
      "False\n",
      "5.1.conv1.2.bias torch.Size([512])\n",
      "False\n",
      "8.weight torch.Size([256, 1024])\n",
      "False\n",
      "8.bias torch.Size([256])\n",
      "False\n",
      "9.weight torch.Size([10, 256])\n",
      "False\n",
      "9.bias torch.Size([10])\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "learn = cnn_learner(data, models.resnet34, metrics=accuracy)\n",
    "learn = learn.load('unfreeze_imagenet_bs64')\n",
    "learn.freeze()\n",
    "\n",
    "\n",
    "class Flatten(nn.Module):\n",
    "    \"\"\"Flatten every dimension except the batch dimension.\"\"\"\n",
    "    def forward(self, input):\n",
    "        return input.view(input.size(0), -1)\n",
    "\n",
    "\n",
    "def conv2(ni, nf):\n",
    "    \"\"\"fastai conv_layer from `ni` to `nf` channels with stride 2.\"\"\"\n",
    "    return conv_layer(ni, nf, stride=2)\n",
    "\n",
    "\n",
    "class ResBlock(nn.Module):\n",
    "    \"\"\"Identity-shortcut block: x + conv_layer(x); channel count is unchanged.\"\"\"\n",
    "    def __init__(self, nf):\n",
    "        super().__init__()\n",
    "        self.conv1 = conv_layer(nf, nf)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x + self.conv1(x)\n",
    "\n",
    "\n",
    "def conv_and_res(ni, nf):\n",
    "    \"\"\"Downsampling stage: strided conv `ni` -> `nf`, then a residual block.\"\"\"\n",
    "    return nn.Sequential(conv2(ni, nf), ResBlock(nf))\n",
    "\n",
    "\n",
    "def conv_(nf):\n",
    "    \"\"\"Same-resolution stage: conv `nf` -> `nf`, then a residual block.\"\"\"\n",
    "    return nn.Sequential(conv_layer(nf, nf), ResBlock(nf))\n",
    "\n",
    "\n",
    "# Student network. Top-level indices matter: parameter names (and checkpoint keys) start\n",
    "# with them, and the freezing loop below selects blocks by that index.\n",
    "# NOTE(review): there is no activation between the two Linear layers, so they compose to a\n",
    "# single affine map; inserting one would renumber the modules after it.\n",
    "net = nn.Sequential(\n",
    "    conv_layer(3, 64, ks=7, stride=2, padding=3),  # 0: stem\n",
    "    nn.MaxPool2d(3, 2, padding=1),                  # 1: no parameters\n",
    "    conv_(64),                                      # 2\n",
    "    conv_and_res(64, 128),                          # 3\n",
    "    conv_and_res(128, 256),                         # 4\n",
    "    conv_and_res(256, 512),                         # 5\n",
    "    AdaptiveConcatPool2d(),                         # 6: concat of max- and avg-pool\n",
    "    Flatten(),                                      # 7\n",
    "    nn.Linear(2 * 512, 256),                        # 8\n",
    "    nn.Linear(256, hyper_params[\"num_classes\"])     # 9\n",
    ")\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    net = net.cuda()\n",
    "    print('Model on GPU')\n",
    "\n",
    "\n",
    "def trainable_block_index(stage: int) -> int:\n",
    "    \"\"\"Top-level index in `net` of the single block left trainable for `stage`.\n",
    "\n",
    "    Stage 0 is the stem conv (net[0]); stage k >= 1 is net[k + 1] because net[1] is the\n",
    "    parameter-free max-pool. -1 also maps to net[0] (legacy alias for stage 0).\n",
    "    NOTE(review): stage 5 maps to net[6] (AdaptiveConcatPool2d), which has no parameters.\n",
    "    \"\"\"\n",
    "    return 0 if stage == 0 else stage + 1\n",
    "\n",
    "\n",
    "# Freeze everything except the block for the current stage, so blocks trained in\n",
    "# earlier stages stay fixed while this one trains.\n",
    "target_block = str(trainable_block_index(hyper_params['stage']))\n",
    "for name, param in net.named_parameters():\n",
    "    print(name, param.shape)\n",
    "    # Compare the full top-level index (text before the first '.'), not just name[0].\n",
    "    param.requires_grad = name.split('.')[0] == target_block\n",
    "    print(param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 112, 112]           9,408\n",
      "              ReLU-2         [-1, 64, 112, 112]               0\n",
      "       BatchNorm2d-3         [-1, 64, 112, 112]             128\n",
      "         MaxPool2d-4           [-1, 64, 56, 56]               0\n",
      "            Conv2d-5           [-1, 64, 56, 56]          36,864\n",
      "              ReLU-6           [-1, 64, 56, 56]               0\n",
      "       BatchNorm2d-7           [-1, 64, 56, 56]             128\n",
      "            Conv2d-8           [-1, 64, 56, 56]          36,864\n",
      "              ReLU-9           [-1, 64, 56, 56]               0\n",
      "      BatchNorm2d-10           [-1, 64, 56, 56]             128\n",
      "         ResBlock-11           [-1, 64, 56, 56]               0\n",
      "           Conv2d-12          [-1, 128, 28, 28]          73,728\n",
      "             ReLU-13          [-1, 128, 28, 28]               0\n",
      "      BatchNorm2d-14          [-1, 128, 28, 28]             256\n",
      "           Conv2d-15          [-1, 128, 28, 28]         147,456\n",
      "             ReLU-16          [-1, 128, 28, 28]               0\n",
      "      BatchNorm2d-17          [-1, 128, 28, 28]             256\n",
      "         ResBlock-18          [-1, 128, 28, 28]               0\n",
      "           Conv2d-19          [-1, 256, 14, 14]         294,912\n",
      "             ReLU-20          [-1, 256, 14, 14]               0\n",
      "      BatchNorm2d-21          [-1, 256, 14, 14]             512\n",
      "           Conv2d-22          [-1, 256, 14, 14]         589,824\n",
      "             ReLU-23          [-1, 256, 14, 14]               0\n",
      "      BatchNorm2d-24          [-1, 256, 14, 14]             512\n",
      "         ResBlock-25          [-1, 256, 14, 14]               0\n",
      "           Conv2d-26            [-1, 512, 7, 7]       1,179,648\n",
      "             ReLU-27            [-1, 512, 7, 7]               0\n",
      "      BatchNorm2d-28            [-1, 512, 7, 7]           1,024\n",
      "           Conv2d-29            [-1, 512, 7, 7]       2,359,296\n",
      "             ReLU-30            [-1, 512, 7, 7]               0\n",
      "      BatchNorm2d-31            [-1, 512, 7, 7]           1,024\n",
      "         ResBlock-32            [-1, 512, 7, 7]               0\n",
      "AdaptiveMaxPool2d-33            [-1, 512, 1, 1]               0\n",
      "AdaptiveAvgPool2d-34            [-1, 512, 1, 1]               0\n",
      "AdaptiveConcatPool2d-35           [-1, 1024, 1, 1]               0\n",
      "          Flatten-36                 [-1, 1024]               0\n",
      "           Linear-37                  [-1, 256]         262,400\n",
      "           Linear-38                   [-1, 10]           2,570\n",
      "================================================================\n",
      "Total params: 4,996,938\n",
      "Trainable params: 9,536\n",
      "Non-trainable params: 4,987,402\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.57\n",
      "Forward/backward pass size (MB): 40.03\n",
      "Params size (MB): 19.06\n",
      "Estimated Total Size (MB): 59.67\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Layer-by-layer summary of the student for one 3x224x224 input. torchsummary defaults to\n",
    "# device='cuda', so pass the device explicitly to match where the model was placed.\n",
    "summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "summary(net, (3, 224, 224), device=summary_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SaveFeatures:\n",
    "    \"\"\"Forward hook that keeps the latest output of module `m` in `self.features`.\"\"\"\n",
    "\n",
    "    def __init__(self, m):\n",
    "        self.handle = m.register_forward_hook(self.hook_fn)\n",
    "\n",
    "    def hook_fn(self, m, inp, outp):\n",
    "        # Stored as-is (not detached): any autograd graph behind `outp` stays reachable.\n",
    "        self.features = outp\n",
    "\n",
    "    def remove(self):\n",
    "        \"\"\"Unregister the hook from its module.\"\"\"\n",
    "        self.handle.remove()\n",
    "\n",
    "\n",
    "# Re-running this cell would otherwise stack another set of hooks on the same modules.\n",
    "for old_hook in list(globals().get('sf', [])) + list(globals().get('sf2', [])):\n",
    "    old_hook.remove()\n",
    "\n",
    "mdl = learn.model\n",
    "# Teacher (ResNet-34 body): stem ReLU output, then the four residual stages.\n",
    "sf = [SaveFeatures(m) for m in [mdl[0][2], mdl[0][4], mdl[0][5], mdl[0][6], mdl[0][7]]]\n",
    "# Student: the matching blocks of `net` (stem conv, then the four conv stages).\n",
    "sf2 = [SaveFeatures(m) for m in [net[0], net[2], net[3], net[4], net[5]]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 112, 112]           9,408\n",
      "       BatchNorm2d-2         [-1, 64, 112, 112]             128\n",
      "              ReLU-3         [-1, 64, 112, 112]               0\n",
      "         MaxPool2d-4           [-1, 64, 56, 56]               0\n",
      "            Conv2d-5           [-1, 64, 56, 56]          36,864\n",
      "       BatchNorm2d-6           [-1, 64, 56, 56]             128\n",
      "              ReLU-7           [-1, 64, 56, 56]               0\n",
      "            Conv2d-8           [-1, 64, 56, 56]          36,864\n",
      "       BatchNorm2d-9           [-1, 64, 56, 56]             128\n",
      "             ReLU-10           [-1, 64, 56, 56]               0\n",
      "       BasicBlock-11           [-1, 64, 56, 56]               0\n",
      "           Conv2d-12           [-1, 64, 56, 56]          36,864\n",
      "      BatchNorm2d-13           [-1, 64, 56, 56]             128\n",
      "             ReLU-14           [-1, 64, 56, 56]               0\n",
      "           Conv2d-15           [-1, 64, 56, 56]          36,864\n",
      "      BatchNorm2d-16           [-1, 64, 56, 56]             128\n",
      "             ReLU-17           [-1, 64, 56, 56]               0\n",
      "       BasicBlock-18           [-1, 64, 56, 56]               0\n",
      "           Conv2d-19           [-1, 64, 56, 56]          36,864\n",
      "      BatchNorm2d-20           [-1, 64, 56, 56]             128\n",
      "             ReLU-21           [-1, 64, 56, 56]               0\n",
      "           Conv2d-22           [-1, 64, 56, 56]          36,864\n",
      "      BatchNorm2d-23           [-1, 64, 56, 56]             128\n",
      "             ReLU-24           [-1, 64, 56, 56]               0\n",
      "       BasicBlock-25           [-1, 64, 56, 56]               0\n",
      "           Conv2d-26          [-1, 128, 28, 28]          73,728\n",
      "      BatchNorm2d-27          [-1, 128, 28, 28]             256\n",
      "             ReLU-28          [-1, 128, 28, 28]               0\n",
      "           Conv2d-29          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-30          [-1, 128, 28, 28]             256\n",
      "           Conv2d-31          [-1, 128, 28, 28]           8,192\n",
      "      BatchNorm2d-32          [-1, 128, 28, 28]             256\n",
      "             ReLU-33          [-1, 128, 28, 28]               0\n",
      "       BasicBlock-34          [-1, 128, 28, 28]               0\n",
      "           Conv2d-35          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-36          [-1, 128, 28, 28]             256\n",
      "             ReLU-37          [-1, 128, 28, 28]               0\n",
      "           Conv2d-38          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-39          [-1, 128, 28, 28]             256\n",
      "             ReLU-40          [-1, 128, 28, 28]               0\n",
      "       BasicBlock-41          [-1, 128, 28, 28]               0\n",
      "           Conv2d-42          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-43          [-1, 128, 28, 28]             256\n",
      "             ReLU-44          [-1, 128, 28, 28]               0\n",
      "           Conv2d-45          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-46          [-1, 128, 28, 28]             256\n",
      "             ReLU-47          [-1, 128, 28, 28]               0\n",
      "       BasicBlock-48          [-1, 128, 28, 28]               0\n",
      "           Conv2d-49          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-50          [-1, 128, 28, 28]             256\n",
      "             ReLU-51          [-1, 128, 28, 28]               0\n",
      "           Conv2d-52          [-1, 128, 28, 28]         147,456\n",
      "      BatchNorm2d-53          [-1, 128, 28, 28]             256\n",
      "             ReLU-54          [-1, 128, 28, 28]               0\n",
      "       BasicBlock-55          [-1, 128, 28, 28]               0\n",
      "           Conv2d-56          [-1, 256, 14, 14]         294,912\n",
      "      BatchNorm2d-57          [-1, 256, 14, 14]             512\n",
      "             ReLU-58          [-1, 256, 14, 14]               0\n",
      "           Conv2d-59          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-60          [-1, 256, 14, 14]             512\n",
      "           Conv2d-61          [-1, 256, 14, 14]          32,768\n",
      "      BatchNorm2d-62          [-1, 256, 14, 14]             512\n",
      "             ReLU-63          [-1, 256, 14, 14]               0\n",
      "       BasicBlock-64          [-1, 256, 14, 14]               0\n",
      "           Conv2d-65          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-66          [-1, 256, 14, 14]             512\n",
      "             ReLU-67          [-1, 256, 14, 14]               0\n",
      "           Conv2d-68          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-69          [-1, 256, 14, 14]             512\n",
      "             ReLU-70          [-1, 256, 14, 14]               0\n",
      "       BasicBlock-71          [-1, 256, 14, 14]               0\n",
      "           Conv2d-72          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-73          [-1, 256, 14, 14]             512\n",
      "             ReLU-74          [-1, 256, 14, 14]               0\n",
      "           Conv2d-75          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-76          [-1, 256, 14, 14]             512\n",
      "             ReLU-77          [-1, 256, 14, 14]               0\n",
      "       BasicBlock-78          [-1, 256, 14, 14]               0\n",
      "           Conv2d-79          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-80          [-1, 256, 14, 14]             512\n",
      "             ReLU-81          [-1, 256, 14, 14]               0\n",
      "           Conv2d-82          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-83          [-1, 256, 14, 14]             512\n",
      "             ReLU-84          [-1, 256, 14, 14]               0\n",
      "       BasicBlock-85          [-1, 256, 14, 14]               0\n",
      "           Conv2d-86          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-87          [-1, 256, 14, 14]             512\n",
      "             ReLU-88          [-1, 256, 14, 14]               0\n",
      "           Conv2d-89          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-90          [-1, 256, 14, 14]             512\n",
      "             ReLU-91          [-1, 256, 14, 14]               0\n",
      "       BasicBlock-92          [-1, 256, 14, 14]               0\n",
      "           Conv2d-93          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-94          [-1, 256, 14, 14]             512\n",
      "             ReLU-95          [-1, 256, 14, 14]               0\n",
      "           Conv2d-96          [-1, 256, 14, 14]         589,824\n",
      "      BatchNorm2d-97          [-1, 256, 14, 14]             512\n",
      "             ReLU-98          [-1, 256, 14, 14]               0\n",
      "       BasicBlock-99          [-1, 256, 14, 14]               0\n",
      "          Conv2d-100            [-1, 512, 7, 7]       1,179,648\n",
      "     BatchNorm2d-101            [-1, 512, 7, 7]           1,024\n",
      "            ReLU-102            [-1, 512, 7, 7]               0\n",
      "          Conv2d-103            [-1, 512, 7, 7]       2,359,296\n",
      "     BatchNorm2d-104            [-1, 512, 7, 7]           1,024\n",
      "          Conv2d-105            [-1, 512, 7, 7]         131,072\n",
      "     BatchNorm2d-106            [-1, 512, 7, 7]           1,024\n",
      "            ReLU-107            [-1, 512, 7, 7]               0\n",
      "      BasicBlock-108            [-1, 512, 7, 7]               0\n",
      "          Conv2d-109            [-1, 512, 7, 7]       2,359,296\n",
      "     BatchNorm2d-110            [-1, 512, 7, 7]           1,024\n",
      "            ReLU-111            [-1, 512, 7, 7]               0\n",
      "          Conv2d-112            [-1, 512, 7, 7]       2,359,296\n",
      "     BatchNorm2d-113            [-1, 512, 7, 7]           1,024\n",
      "            ReLU-114            [-1, 512, 7, 7]               0\n",
      "      BasicBlock-115            [-1, 512, 7, 7]               0\n",
      "          Conv2d-116            [-1, 512, 7, 7]       2,359,296\n",
      "     BatchNorm2d-117            [-1, 512, 7, 7]           1,024\n",
      "            ReLU-118            [-1, 512, 7, 7]               0\n",
      "          Conv2d-119            [-1, 512, 7, 7]       2,359,296\n",
      "     BatchNorm2d-120            [-1, 512, 7, 7]           1,024\n",
      "            ReLU-121            [-1, 512, 7, 7]               0\n",
      "      BasicBlock-122            [-1, 512, 7, 7]               0\n",
      "AdaptiveMaxPool2d-123            [-1, 512, 1, 1]               0\n",
      "AdaptiveAvgPool2d-124            [-1, 512, 1, 1]               0\n",
      "AdaptiveConcatPool2d-125           [-1, 1024, 1, 1]               0\n",
      "         Flatten-126                 [-1, 1024]               0\n",
      "     BatchNorm1d-127                 [-1, 1024]           2,048\n",
      "         Dropout-128                 [-1, 1024]               0\n",
      "          Linear-129                  [-1, 512]         524,800\n",
      "            ReLU-130                  [-1, 512]               0\n",
      "     BatchNorm1d-131                  [-1, 512]           1,024\n",
      "         Dropout-132                  [-1, 512]               0\n",
      "          Linear-133                   [-1, 10]           5,130\n",
      "================================================================\n",
      "Total params: 21,817,674\n",
      "Trainable params: 550,026\n",
      "Non-trainable params: 21,267,648\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.57\n",
      "Forward/backward pass size (MB): 96.33\n",
      "Params size (MB): 83.23\n",
      "Estimated Total Size (MB): 180.13\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Same summary for the teacher (fastai ResNet-34 body + classification head). torchsummary\n",
    "# defaults to device='cuda', so pass the device explicitly.\n",
    "summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "summary(mdl, (3, 224, 224), device=summary_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([64, 64, 112, 112])\n",
      "torch.Size([64, 64, 56, 56])\n",
      "torch.Size([64, 128, 28, 28])\n",
      "torch.Size([64, 256, 14, 14])\n",
      "torch.Size([64, 512, 7, 7])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check on one training batch: each hooked teacher block and its student\n",
    "# counterpart must produce feature maps of the same shape.\n",
    "x, y = next(iter(data.train_dl))\n",
    "x = x.to(next(net.parameters()).device)\n",
    "# Run in eval mode without autograd so the check neither updates BatchNorm running\n",
    "# statistics nor keeps a graph alive; restore the previous modes afterwards.\n",
    "was_training = (mdl.training, net.training)\n",
    "mdl.eval()\n",
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    mdl(x)\n",
    "    net(x)\n",
    "mdl.train(was_training[0])\n",
    "net.train(was_training[1])\n",
    "for i, (teacher_hook, student_hook) in enumerate(zip(sf, sf2)):\n",
    "    print(teacher_hook.features.shape)\n",
    "    assert teacher_hook.features.shape == student_hook.features.shape, (\n",
    "        f\"stage {i}: teacher {tuple(teacher_hook.features.shape)} \"\n",
    "        f\"vs student {tuple(student_hook.features.shape)}\")\n",
    "del x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training split: a LabelList pairing images (x) with category labels (y).\n",
    "data.train_ds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Stage-wise training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "COMET INFO: ----------------------------\n",
      "COMET INFO: Comet.ml Experiment Summary:\n",
      "COMET INFO:   Data:\n",
      "COMET INFO:     url: https://www.comet.ml/akshaykvnit/kd0/2c7d06194a994c3daec1f9da8cda7de1\n",
      "COMET INFO:   Metrics [count] (min, max):\n",
      "COMET INFO:     loss [2010]                   : (0.02099142037332058, 0.8028264045715332)\n",
      "COMET INFO:     sys.gpu.0.free_memory [89]    : (7528251392.0, 7597457408.0)\n",
      "COMET INFO:     sys.gpu.0.gpu_utilization [89]: (0.0, 100.0)\n",
      "COMET INFO:     sys.gpu.0.total_memory        : (11721506816.0, 11721506816.0)\n",
      "COMET INFO:     sys.gpu.0.used_memory [89]    : (4124049408.0, 4193255424.0)\n",
      "COMET INFO:     sys.gpu.1.free_memory [89]    : (6221987840.0, 6221987840.0)\n",
      "COMET INFO:     sys.gpu.1.gpu_utilization [89]: (0.0, 0.0)\n",
      "COMET INFO:     sys.gpu.1.total_memory        : (6233391104.0, 6233391104.0)\n",
      "COMET INFO:     sys.gpu.1.used_memory [89]    : (11403264.0, 11403264.0)\n",
      "COMET INFO:     train_loss [100]              : (0.022360381211585074, 0.43217234439517727)\n",
      "COMET INFO:     val_loss [100]                : (0.021705504972487688, 0.2769218385219574)\n",
      "COMET INFO: ----------------------------\n",
      "COMET INFO: Experiment is live on comet.ml https://www.comet.ml/akshaykvnit/kd0/67ce178b375745afbf4d5e2093e008d6\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch =  0  step =  50  of total steps  201  loss =  1.3363007307052612\n",
      "epoch =  0  step =  100  of total steps  201  loss =  1.5924104452133179\n",
      "epoch =  0  step =  150  of total steps  201  loss =  2.063514471054077\n",
      "epoch =  0  step =  200  of total steps  201  loss =  0.6800873279571533\n",
      "epoch :  1  /  100  | TL :  1.4434532529086024  | VL :  1.098916418850422\n",
      "saving model\n",
      "epoch =  1  step =  50  of total steps  201  loss =  1.4276881217956543\n",
      "epoch =  1  step =  100  of total steps  201  loss =  0.7155402898788452\n",
      "epoch =  1  step =  150  of total steps  201  loss =  2.5248448848724365\n",
      "epoch =  1  step =  200  of total steps  201  loss =  1.4483039379119873\n",
      "epoch :  2  /  100  | TL :  1.0154912258262065  | VL :  0.9642607606947422\n",
      "saving model\n",
      "epoch =  2  step =  50  of total steps  201  loss =  0.5965914130210876\n",
      "epoch =  2  step =  100  of total steps  201  loss =  0.5475767254829407\n",
      "epoch =  2  step =  150  of total steps  201  loss =  0.8009404540061951\n",
      "epoch =  2  step =  200  of total steps  201  loss =  0.4995229244232178\n",
      "epoch :  3  /  100  | TL :  0.9318587109817201  | VL :  0.8207416906952858\n",
      "saving model\n",
      "epoch =  3  step =  50  of total steps  201  loss =  3.0689992904663086\n",
      "epoch =  3  step =  100  of total steps  201  loss =  0.4608438014984131\n",
      "epoch =  3  step =  150  of total steps  201  loss =  0.6237032413482666\n",
      "epoch =  3  step =  200  of total steps  201  loss =  2.1428074836730957\n",
      "epoch :  4  /  100  | TL :  0.8621557179968156  | VL :  0.7825270257890224\n",
      "saving model\n",
      "epoch =  4  step =  50  of total steps  201  loss =  1.2216565608978271\n",
      "epoch =  4  step =  100  of total steps  201  loss =  0.6518152952194214\n",
      "epoch =  4  step =  150  of total steps  201  loss =  0.4315777122974396\n",
      "epoch =  4  step =  200  of total steps  201  loss =  0.41962507367134094\n",
      "epoch :  5  /  100  | TL :  0.8007225478763011  | VL :  0.6957807466387749\n",
      "saving model\n",
      "epoch =  5  step =  50  of total steps  201  loss =  1.599657654762268\n",
      "epoch =  5  step =  100  of total steps  201  loss =  0.4515252411365509\n",
      "epoch =  5  step =  150  of total steps  201  loss =  0.45117059350013733\n",
      "epoch =  5  step =  200  of total steps  201  loss =  0.8993934988975525\n",
      "epoch :  6  /  100  | TL :  0.7105142744026374  | VL :  0.6934961639344692\n",
      "saving model\n",
      "epoch =  6  step =  50  of total steps  201  loss =  0.3790135979652405\n",
      "epoch =  6  step =  100  of total steps  201  loss =  0.36572226881980896\n",
      "epoch =  6  step =  150  of total steps  201  loss =  0.3814786672592163\n",
      "epoch =  6  step =  200  of total steps  201  loss =  0.3712798058986664\n",
      "epoch :  7  /  100  | TL :  0.6844446189071408  | VL :  0.621474876999855\n",
      "saving model\n",
      "epoch =  7  step =  50  of total steps  201  loss =  0.5346654057502747\n",
      "epoch =  7  step =  100  of total steps  201  loss =  0.5427800416946411\n",
      "epoch =  7  step =  150  of total steps  201  loss =  0.3806150257587433\n",
      "epoch =  7  step =  200  of total steps  201  loss =  0.4223564565181732\n",
      "epoch :  8  /  100  | TL :  0.6580223737664483  | VL :  0.5301005467772484\n",
      "saving model\n",
      "epoch =  8  step =  50  of total steps  201  loss =  0.3115065097808838\n",
      "epoch =  8  step =  100  of total steps  201  loss =  0.43900254368782043\n",
      "epoch =  8  step =  150  of total steps  201  loss =  0.31443771719932556\n",
      "epoch =  8  step =  200  of total steps  201  loss =  0.5655535459518433\n",
      "epoch :  9  /  100  | TL :  0.6271749789738537  | VL :  0.5643847901374102\n",
      "epoch =  9  step =  50  of total steps  201  loss =  0.7301633954048157\n",
      "epoch =  9  step =  100  of total steps  201  loss =  0.2610112428665161\n",
      "epoch =  9  step =  150  of total steps  201  loss =  0.3311072289943695\n",
      "epoch =  9  step =  200  of total steps  201  loss =  0.39985889196395874\n",
      "epoch :  10  /  100  | TL :  0.5834003601649508  | VL :  0.5350679364055395\n",
      "epoch =  10  step =  50  of total steps  201  loss =  0.3054298162460327\n",
      "epoch =  10  step =  100  of total steps  201  loss =  0.40433695912361145\n",
      "epoch =  10  step =  150  of total steps  201  loss =  0.2593206763267517\n",
      "epoch =  10  step =  200  of total steps  201  loss =  1.1166980266571045\n",
      "epoch :  11  /  100  | TL :  0.5771213586798948  | VL :  0.499354287981987\n",
      "saving model\n",
      "epoch =  11  step =  50  of total steps  201  loss =  0.6296492218971252\n",
      "epoch =  11  step =  100  of total steps  201  loss =  0.41609153151512146\n",
      "epoch =  11  step =  150  of total steps  201  loss =  0.49333998560905457\n",
      "epoch =  11  step =  200  of total steps  201  loss =  0.9172167181968689\n",
      "epoch :  12  /  100  | TL :  0.5431875599854028  | VL :  0.47174666076898575\n",
      "saving model\n",
      "epoch =  12  step =  50  of total steps  201  loss =  0.8996605277061462\n",
      "epoch =  12  step =  100  of total steps  201  loss =  0.27987998723983765\n",
      "epoch =  12  step =  150  of total steps  201  loss =  0.4144595265388489\n",
      "epoch =  12  step =  200  of total steps  201  loss =  0.3356155753135681\n",
      "epoch :  13  /  100  | TL :  0.5368926181128962  | VL :  0.40158725902438164\n",
      "saving model\n",
      "epoch =  13  step =  50  of total steps  201  loss =  1.3206111192703247\n",
      "epoch =  13  step =  100  of total steps  201  loss =  0.5806638598442078\n",
      "epoch =  13  step =  150  of total steps  201  loss =  0.3863615095615387\n",
      "epoch =  13  step =  200  of total steps  201  loss =  0.36737939715385437\n",
      "epoch :  14  /  100  | TL :  0.5424304147116581  | VL :  0.452036764472723\n",
      "epoch =  14  step =  50  of total steps  201  loss =  0.3468874394893646\n",
      "epoch =  14  step =  100  of total steps  201  loss =  0.4083949327468872\n",
      "epoch =  14  step =  150  of total steps  201  loss =  0.3102068603038788\n",
      "epoch =  14  step =  200  of total steps  201  loss =  0.3953082859516144\n",
      "epoch :  15  /  100  | TL :  0.48443579495842776  | VL :  0.4319888213649392\n",
      "epoch =  15  step =  50  of total steps  201  loss =  0.5770427584648132\n",
      "epoch =  15  step =  100  of total steps  201  loss =  0.40122705698013306\n",
      "epoch =  15  step =  150  of total steps  201  loss =  0.31056612730026245\n",
      "epoch =  15  step =  200  of total steps  201  loss =  0.280568927526474\n",
      "epoch :  16  /  100  | TL :  0.49976452606827465  | VL :  0.41279782820492983\n",
      "epoch =  16  step =  50  of total steps  201  loss =  0.5887799859046936\n",
      "epoch =  16  step =  100  of total steps  201  loss =  0.22548848390579224\n",
      "epoch =  16  step =  150  of total steps  201  loss =  0.5593953728675842\n",
      "epoch =  16  step =  200  of total steps  201  loss =  0.21698534488677979\n",
      "epoch :  17  /  100  | TL :  0.49047917105368716  | VL :  0.35702727548778057\n",
      "saving model\n",
      "epoch =  17  step =  50  of total steps  201  loss =  0.24954600632190704\n",
      "epoch =  17  step =  100  of total steps  201  loss =  0.24545136094093323\n",
      "epoch =  17  step =  150  of total steps  201  loss =  0.4562869966030121\n",
      "epoch =  17  step =  200  of total steps  201  loss =  0.542346179485321\n",
      "epoch :  18  /  100  | TL :  0.4752224958061579  | VL :  0.3814458139240742\n",
      "epoch =  18  step =  50  of total steps  201  loss =  0.370288223028183\n",
      "epoch =  18  step =  100  of total steps  201  loss =  0.473908394575119\n",
      "epoch =  18  step =  150  of total steps  201  loss =  0.38848158717155457\n",
      "epoch =  18  step =  200  of total steps  201  loss =  0.7625256776809692\n",
      "epoch :  19  /  100  | TL :  0.4684422145880277  | VL :  0.3398084379732609\n",
      "saving model\n",
      "epoch =  19  step =  50  of total steps  201  loss =  0.2899431586265564\n",
      "epoch =  19  step =  100  of total steps  201  loss =  0.6636096835136414\n",
      "epoch =  19  step =  150  of total steps  201  loss =  0.3810896873474121\n",
      "epoch =  19  step =  200  of total steps  201  loss =  0.4834587574005127\n",
      "epoch :  20  /  100  | TL :  0.4428754681674995  | VL :  0.34031077940016985\n",
      "epoch =  20  step =  50  of total steps  201  loss =  0.5890750288963318\n",
      "epoch =  20  step =  100  of total steps  201  loss =  0.21237392723560333\n",
      "epoch =  20  step =  150  of total steps  201  loss =  0.2612070143222809\n",
      "epoch =  20  step =  200  of total steps  201  loss =  0.6812886595726013\n",
      "epoch :  21  /  100  | TL :  0.4370809672632028  | VL :  0.3479454657062888\n",
      "epoch =  21  step =  50  of total steps  201  loss =  0.26334598660469055\n",
      "epoch =  21  step =  100  of total steps  201  loss =  0.2590363025665283\n",
      "epoch =  21  step =  150  of total steps  201  loss =  1.6386595964431763\n",
      "epoch =  21  step =  200  of total steps  201  loss =  0.631398618221283\n",
      "epoch :  22  /  100  | TL :  0.43024412086650504  | VL :  0.33866226859390736\n",
      "saving model\n",
      "epoch =  22  step =  50  of total steps  201  loss =  0.5630598068237305\n",
      "epoch =  22  step =  100  of total steps  201  loss =  0.8315630555152893\n",
      "epoch =  22  step =  150  of total steps  201  loss =  0.34214290976524353\n",
      "epoch =  22  step =  200  of total steps  201  loss =  0.28183820843696594\n",
      "epoch :  23  /  100  | TL :  0.4253652346371418  | VL :  0.3214079877361655\n",
      "saving model\n",
      "epoch =  23  step =  50  of total steps  201  loss =  0.1652694046497345\n",
      "epoch =  23  step =  100  of total steps  201  loss =  0.2229609191417694\n",
      "epoch =  23  step =  150  of total steps  201  loss =  0.20646585524082184\n",
      "epoch =  23  step =  200  of total steps  201  loss =  0.26068732142448425\n",
      "epoch :  24  /  100  | TL :  0.44746079040107445  | VL :  0.33735080156475306\n",
      "epoch =  24  step =  50  of total steps  201  loss =  0.5981717109680176\n",
      "epoch =  24  step =  100  of total steps  201  loss =  0.44476038217544556\n",
      "epoch =  24  step =  150  of total steps  201  loss =  0.24599821865558624\n",
      "epoch =  24  step =  200  of total steps  201  loss =  0.26418086886405945\n",
      "epoch :  25  /  100  | TL :  0.4084508720025494  | VL :  0.3122872970998287\n",
      "saving model\n",
      "epoch =  25  step =  50  of total steps  201  loss =  0.25096553564071655\n",
      "epoch =  25  step =  100  of total steps  201  loss =  0.20496909320354462\n",
      "epoch =  25  step =  150  of total steps  201  loss =  0.3300163149833679\n",
      "epoch =  25  step =  200  of total steps  201  loss =  0.21935221552848816\n",
      "epoch :  26  /  100  | TL :  0.4073682907060604  | VL :  0.31151882372796535\n",
      "saving model\n",
      "epoch =  26  step =  50  of total steps  201  loss =  0.3781047761440277\n",
      "epoch =  26  step =  100  of total steps  201  loss =  0.23152758181095123\n",
      "epoch =  26  step =  150  of total steps  201  loss =  0.31270262598991394\n",
      "epoch =  26  step =  200  of total steps  201  loss =  0.3461587429046631\n",
      "epoch :  27  /  100  | TL :  0.3908444852081697  | VL :  0.3363481452688575\n",
      "epoch =  27  step =  50  of total steps  201  loss =  0.7449565529823303\n",
      "epoch =  27  step =  100  of total steps  201  loss =  0.5084343552589417\n",
      "epoch =  27  step =  150  of total steps  201  loss =  0.6685928106307983\n",
      "epoch =  27  step =  200  of total steps  201  loss =  0.2970266044139862\n",
      "epoch :  28  /  100  | TL :  0.39306186703010576  | VL :  0.2896931767463684\n",
      "saving model\n",
      "epoch =  28  step =  50  of total steps  201  loss =  0.21631470322608948\n",
      "epoch =  28  step =  100  of total steps  201  loss =  0.6003698706626892\n",
      "epoch =  28  step =  150  of total steps  201  loss =  0.19499969482421875\n",
      "epoch =  28  step =  200  of total steps  201  loss =  0.2642862796783447\n",
      "epoch :  29  /  100  | TL :  0.40497175882111736  | VL :  0.3147032596170902\n",
      "epoch =  29  step =  50  of total steps  201  loss =  0.15894146263599396\n",
      "epoch =  29  step =  100  of total steps  201  loss =  0.6727049350738525\n",
      "epoch =  29  step =  150  of total steps  201  loss =  0.18029740452766418\n",
      "epoch =  29  step =  200  of total steps  201  loss =  0.339262455701828\n",
      "epoch :  30  /  100  | TL :  0.3854239663822734  | VL :  0.28547164145857096\n",
      "saving model\n",
      "epoch =  30  step =  50  of total steps  201  loss =  0.3958902060985565\n",
      "epoch =  30  step =  100  of total steps  201  loss =  0.1565287858247757\n",
      "epoch =  30  step =  150  of total steps  201  loss =  0.2896653115749359\n",
      "epoch =  30  step =  200  of total steps  201  loss =  0.3154236078262329\n",
      "epoch :  31  /  100  | TL :  0.366778151461141  | VL :  0.29818573873490095\n",
      "epoch =  31  step =  50  of total steps  201  loss =  0.5270338654518127\n",
      "epoch =  31  step =  100  of total steps  201  loss =  0.4053179621696472\n",
      "epoch =  31  step =  150  of total steps  201  loss =  0.3229493796825409\n",
      "epoch =  31  step =  200  of total steps  201  loss =  0.2927687168121338\n",
      "epoch :  32  /  100  | TL :  0.3467764744711159  | VL :  0.3004876817576587\n",
      "epoch =  32  step =  50  of total steps  201  loss =  0.21472381055355072\n",
      "epoch =  32  step =  100  of total steps  201  loss =  0.2093142867088318\n",
      "epoch =  32  step =  150  of total steps  201  loss =  0.3176526427268982\n",
      "epoch =  32  step =  200  of total steps  201  loss =  0.3022531270980835\n",
      "epoch :  33  /  100  | TL :  0.381299135326153  | VL :  0.27031920850276947\n",
      "saving model\n",
      "epoch =  33  step =  50  of total steps  201  loss =  0.1594221442937851\n",
      "epoch =  33  step =  100  of total steps  201  loss =  0.24892954528331757\n",
      "epoch =  33  step =  150  of total steps  201  loss =  0.21551889181137085\n",
      "epoch =  33  step =  200  of total steps  201  loss =  0.35918721556663513\n",
      "epoch :  34  /  100  | TL :  0.36008523256327973  | VL :  0.2838273346424103\n",
      "epoch =  34  step =  50  of total steps  201  loss =  0.44498690962791443\n",
      "epoch =  34  step =  100  of total steps  201  loss =  0.17685286700725555\n",
      "epoch =  34  step =  150  of total steps  201  loss =  0.21400302648544312\n",
      "epoch =  34  step =  200  of total steps  201  loss =  0.2627808749675751\n",
      "epoch :  35  /  100  | TL :  0.3754857507955969  | VL :  0.28530052956193686\n",
      "epoch =  35  step =  50  of total steps  201  loss =  0.18569758534431458\n",
      "epoch =  35  step =  100  of total steps  201  loss =  0.17292045056819916\n",
      "epoch =  35  step =  150  of total steps  201  loss =  0.20305468142032623\n",
      "epoch =  35  step =  200  of total steps  201  loss =  0.1883557289838791\n",
      "epoch :  36  /  100  | TL :  0.35727623082808596  | VL :  0.2837277967482805\n",
      "epoch =  36  step =  50  of total steps  201  loss =  0.6409269571304321\n",
      "epoch =  36  step =  100  of total steps  201  loss =  0.8131771087646484\n",
      "epoch =  36  step =  150  of total steps  201  loss =  0.1637154221534729\n",
      "epoch =  36  step =  200  of total steps  201  loss =  0.5723938345909119\n",
      "epoch :  37  /  100  | TL :  0.34782043180952027  | VL :  0.2924207993783057\n",
      "epoch =  37  step =  50  of total steps  201  loss =  0.14842815697193146\n",
      "epoch =  37  step =  100  of total steps  201  loss =  0.24217888712882996\n",
      "epoch =  37  step =  150  of total steps  201  loss =  0.21252775192260742\n",
      "epoch =  37  step =  200  of total steps  201  loss =  0.2758912444114685\n",
      "epoch :  38  /  100  | TL :  0.3760882421809049  | VL :  0.2914697336964309\n",
      "epoch =  38  step =  50  of total steps  201  loss =  0.15227453410625458\n",
      "epoch =  38  step =  100  of total steps  201  loss =  1.7099878787994385\n",
      "epoch =  38  step =  150  of total steps  201  loss =  0.12185534089803696\n",
      "epoch =  38  step =  200  of total steps  201  loss =  0.3224309980869293\n",
      "epoch :  39  /  100  | TL :  0.3389570086230686  | VL :  0.2755112564191222\n",
      "epoch =  39  step =  50  of total steps  201  loss =  0.1993829309940338\n",
      "epoch =  39  step =  100  of total steps  201  loss =  0.25041723251342773\n",
      "epoch =  39  step =  150  of total steps  201  loss =  0.19249781966209412\n",
      "epoch =  39  step =  200  of total steps  201  loss =  0.2376260906457901\n",
      "epoch :  40  /  100  | TL :  0.3566507949535526  | VL :  0.2756520230323076\n",
      "epoch =  40  step =  50  of total steps  201  loss =  0.22907130420207977\n",
      "epoch =  40  step =  100  of total steps  201  loss =  0.39299866557121277\n",
      "epoch =  40  step =  150  of total steps  201  loss =  0.18733592331409454\n",
      "epoch =  40  step =  200  of total steps  201  loss =  0.18313150107860565\n",
      "epoch :  41  /  100  | TL :  0.3580860382140572  | VL :  0.2807845510542393\n",
      "epoch =  41  step =  50  of total steps  201  loss =  0.2353949248790741\n",
      "epoch =  41  step =  100  of total steps  201  loss =  0.16716144979000092\n",
      "epoch =  41  step =  150  of total steps  201  loss =  0.650130569934845\n",
      "epoch =  41  step =  200  of total steps  201  loss =  0.2890647053718567\n",
      "epoch :  42  /  100  | TL :  0.3435586250540036  | VL :  0.2623562845401466\n",
      "saving model\n",
      "epoch =  42  step =  50  of total steps  201  loss =  0.32361283898353577\n",
      "epoch =  42  step =  100  of total steps  201  loss =  0.3001026511192322\n",
      "epoch =  42  step =  150  of total steps  201  loss =  0.3207901418209076\n",
      "epoch =  42  step =  200  of total steps  201  loss =  0.47192955017089844\n",
      "epoch :  43  /  100  | TL :  0.34177714434280915  | VL :  0.28764524264261127\n",
      "epoch =  43  step =  50  of total steps  201  loss =  0.1798098385334015\n",
      "epoch =  43  step =  100  of total steps  201  loss =  0.21029500663280487\n",
      "epoch =  43  step =  150  of total steps  201  loss =  0.3249070346355438\n",
      "epoch =  43  step =  200  of total steps  201  loss =  0.8761282563209534\n",
      "epoch :  44  /  100  | TL :  0.34140195026623077  | VL :  0.2816605567932129\n",
      "epoch =  44  step =  50  of total steps  201  loss =  0.2770947515964508\n",
      "epoch =  44  step =  100  of total steps  201  loss =  0.6246691942214966\n",
      "epoch =  44  step =  150  of total steps  201  loss =  0.27642592787742615\n",
      "epoch =  44  step =  200  of total steps  201  loss =  0.21309120953083038\n",
      "epoch :  45  /  100  | TL :  0.35953214475468026  | VL :  0.281871790997684\n",
      "epoch =  45  step =  50  of total steps  201  loss =  0.319578617811203\n",
      "epoch =  45  step =  100  of total steps  201  loss =  0.18373407423496246\n",
      "epoch =  45  step =  150  of total steps  201  loss =  0.12758393585681915\n",
      "epoch =  45  step =  200  of total steps  201  loss =  0.21284501254558563\n",
      "epoch :  46  /  100  | TL :  0.33879356775710834  | VL :  0.25454444997012615\n",
      "saving model\n",
      "epoch =  46  step =  50  of total steps  201  loss =  0.27768468856811523\n",
      "epoch =  46  step =  100  of total steps  201  loss =  0.4010154604911804\n",
      "epoch =  46  step =  150  of total steps  201  loss =  0.1628320813179016\n",
      "epoch =  46  step =  200  of total steps  201  loss =  0.15422652661800385\n",
      "epoch :  47  /  100  | TL :  0.33997564284659143  | VL :  0.25750750536099076\n",
      "epoch =  47  step =  50  of total steps  201  loss =  0.853797435760498\n",
      "epoch =  47  step =  100  of total steps  201  loss =  0.9594241380691528\n",
      "epoch =  47  step =  150  of total steps  201  loss =  0.22169971466064453\n",
      "epoch =  47  step =  200  of total steps  201  loss =  1.063235878944397\n",
      "epoch :  48  /  100  | TL :  0.3430310811347036  | VL :  0.2552833925001323\n",
      "epoch =  48  step =  50  of total steps  201  loss =  0.1580798625946045\n",
      "epoch =  48  step =  100  of total steps  201  loss =  0.2855444848537445\n",
      "epoch =  48  step =  150  of total steps  201  loss =  0.21979348361492157\n",
      "epoch =  48  step =  200  of total steps  201  loss =  0.12137044966220856\n",
      "epoch :  49  /  100  | TL :  0.31929998652110647  | VL :  0.26289414893835783\n",
      "epoch =  49  step =  50  of total steps  201  loss =  0.3262864649295807\n",
      "epoch =  49  step =  100  of total steps  201  loss =  0.9740334153175354\n",
      "epoch =  49  step =  150  of total steps  201  loss =  0.25047481060028076\n",
      "epoch =  49  step =  200  of total steps  201  loss =  0.232718363404274\n",
      "epoch :  50  /  100  | TL :  0.34087478751270334  | VL :  0.2679750267416239\n",
      "epoch =  50  step =  50  of total steps  201  loss =  0.20686788856983185\n",
      "epoch =  50  step =  100  of total steps  201  loss =  1.8231375217437744\n",
      "epoch =  50  step =  150  of total steps  201  loss =  0.17962653934955597\n",
      "epoch =  50  step =  200  of total steps  201  loss =  1.2918907403945923\n",
      "epoch :  51  /  100  | TL :  0.34228081541571453  | VL :  0.25287877209484577\n",
      "saving model\n",
      "epoch =  51  step =  50  of total steps  201  loss =  0.24822650849819183\n",
      "epoch =  51  step =  100  of total steps  201  loss =  1.060990810394287\n",
      "epoch =  51  step =  150  of total steps  201  loss =  0.2228943556547165\n",
      "epoch =  51  step =  200  of total steps  201  loss =  0.5971835851669312\n",
      "epoch :  52  /  100  | TL :  0.3336658736514808  | VL :  0.28283419366925955\n",
      "epoch =  52  step =  50  of total steps  201  loss =  0.1823933869600296\n",
      "epoch =  52  step =  100  of total steps  201  loss =  0.22288013994693756\n",
      "epoch =  52  step =  150  of total steps  201  loss =  0.3394603133201599\n",
      "epoch =  52  step =  200  of total steps  201  loss =  0.2549245357513428\n",
      "epoch :  53  /  100  | TL :  0.33098161391061337  | VL :  0.25668712286278605\n",
      "epoch =  53  step =  50  of total steps  201  loss =  0.36427751183509827\n",
      "epoch =  53  step =  100  of total steps  201  loss =  0.22846123576164246\n",
      "epoch =  53  step =  150  of total steps  201  loss =  0.13803847134113312\n",
      "epoch =  53  step =  200  of total steps  201  loss =  0.21120736002922058\n",
      "epoch :  54  /  100  | TL :  0.32385031738091463  | VL :  0.2471815780736506\n",
      "saving model\n",
      "epoch =  54  step =  50  of total steps  201  loss =  0.1831803321838379\n",
      "epoch =  54  step =  100  of total steps  201  loss =  0.4763934314250946\n",
      "epoch =  54  step =  150  of total steps  201  loss =  0.2939941883087158\n",
      "epoch =  54  step =  200  of total steps  201  loss =  0.2658322751522064\n",
      "epoch :  55  /  100  | TL :  0.31914823465234604  | VL :  0.2446282897144556\n",
      "saving model\n",
      "epoch =  55  step =  50  of total steps  201  loss =  0.3222612142562866\n",
      "epoch =  55  step =  100  of total steps  201  loss =  0.21017228066921234\n",
      "epoch =  55  step =  150  of total steps  201  loss =  0.1848582625389099\n",
      "epoch =  55  step =  200  of total steps  201  loss =  0.8145236968994141\n",
      "epoch :  56  /  100  | TL :  0.33497415557133026  | VL :  0.2742465604096651\n",
      "epoch =  56  step =  50  of total steps  201  loss =  0.18441718816757202\n",
      "epoch =  56  step =  100  of total steps  201  loss =  0.2542833685874939\n",
      "epoch =  56  step =  150  of total steps  201  loss =  0.2955993413925171\n",
      "epoch =  56  step =  200  of total steps  201  loss =  0.2531839907169342\n",
      "epoch :  57  /  100  | TL :  0.349257964622322  | VL :  0.24810365634039044\n",
      "epoch =  57  step =  50  of total steps  201  loss =  0.4104498624801636\n",
      "epoch =  57  step =  100  of total steps  201  loss =  0.6196351051330566\n",
      "epoch =  57  step =  150  of total steps  201  loss =  0.42129990458488464\n",
      "epoch =  57  step =  200  of total steps  201  loss =  0.4629940092563629\n",
      "epoch :  58  /  100  | TL :  0.3045980268672331  | VL :  0.2583562806248665\n",
      "epoch =  58  step =  50  of total steps  201  loss =  0.3582005202770233\n",
      "epoch =  58  step =  100  of total steps  201  loss =  0.2952069640159607\n",
      "epoch =  58  step =  150  of total steps  201  loss =  0.14940683543682098\n",
      "epoch =  58  step =  200  of total steps  201  loss =  0.3466206192970276\n",
      "epoch :  59  /  100  | TL :  0.33016527299560716  | VL :  0.26817572861909866\n",
      "epoch =  59  step =  50  of total steps  201  loss =  0.16036050021648407\n",
      "epoch =  59  step =  100  of total steps  201  loss =  0.29583507776260376\n",
      "epoch =  59  step =  150  of total steps  201  loss =  0.17785495519638062\n",
      "epoch =  59  step =  200  of total steps  201  loss =  0.19085447490215302\n",
      "epoch :  60  /  100  | TL :  0.3388544599808271  | VL :  0.26932147052139044\n",
      "epoch =  60  step =  50  of total steps  201  loss =  0.23932471871376038\n",
      "epoch =  60  step =  100  of total steps  201  loss =  0.30463099479675293\n",
      "epoch =  60  step =  150  of total steps  201  loss =  0.20089861750602722\n",
      "epoch =  60  step =  200  of total steps  201  loss =  0.3214413821697235\n",
      "epoch :  61  /  100  | TL :  0.31674468424634555  | VL :  0.26212842809036374\n",
      "epoch =  61  step =  50  of total steps  201  loss =  0.22611333429813385\n",
      "epoch =  61  step =  100  of total steps  201  loss =  0.19795522093772888\n",
      "epoch =  61  step =  150  of total steps  201  loss =  0.20129217207431793\n",
      "epoch =  61  step =  200  of total steps  201  loss =  0.8739120960235596\n",
      "epoch :  62  /  100  | TL :  0.32827377916123734  | VL :  0.26649915473535657\n",
      "epoch =  62  step =  50  of total steps  201  loss =  0.26990658044815063\n",
      "epoch =  62  step =  100  of total steps  201  loss =  0.20525354146957397\n",
      "epoch =  62  step =  150  of total steps  201  loss =  0.2374250590801239\n",
      "epoch =  62  step =  200  of total steps  201  loss =  0.1557898074388504\n",
      "epoch :  63  /  100  | TL :  0.3015441315535882  | VL :  0.25010826624929905\n",
      "epoch =  63  step =  50  of total steps  201  loss =  0.25975146889686584\n",
      "epoch =  63  step =  100  of total steps  201  loss =  0.20849862694740295\n",
      "epoch =  63  step =  150  of total steps  201  loss =  0.20462629199028015\n",
      "epoch =  63  step =  200  of total steps  201  loss =  0.22674210369586945\n",
      "epoch :  64  /  100  | TL :  0.31528659305762297  | VL :  0.257613786496222\n",
      "epoch =  64  step =  50  of total steps  201  loss =  0.26988258957862854\n",
      "epoch =  64  step =  100  of total steps  201  loss =  0.22434557974338531\n",
      "epoch =  64  step =  150  of total steps  201  loss =  0.3786505162715912\n",
      "epoch =  64  step =  200  of total steps  201  loss =  0.6616008877754211\n",
      "epoch :  65  /  100  | TL :  0.33964296700942576  | VL :  0.2573827882297337\n",
      "epoch =  65  step =  50  of total steps  201  loss =  0.23358364403247833\n",
      "epoch =  65  step =  100  of total steps  201  loss =  0.21801809966564178\n",
      "epoch =  65  step =  150  of total steps  201  loss =  1.6644752025604248\n",
      "epoch =  65  step =  200  of total steps  201  loss =  1.2831969261169434\n",
      "epoch :  66  /  100  | TL :  0.3162688270433625  | VL :  0.23327945638448\n",
      "saving model\n",
      "epoch =  66  step =  50  of total steps  201  loss =  0.2627890408039093\n",
      "epoch =  66  step =  100  of total steps  201  loss =  0.32500091195106506\n",
      "epoch =  66  step =  150  of total steps  201  loss =  0.33325570821762085\n",
      "epoch =  66  step =  200  of total steps  201  loss =  0.20968583226203918\n",
      "epoch :  67  /  100  | TL :  0.3079227608679539  | VL :  0.24270161800086498\n",
      "epoch =  67  step =  50  of total steps  201  loss =  0.3180192708969116\n",
      "epoch =  67  step =  100  of total steps  201  loss =  0.11264127492904663\n",
      "epoch =  67  step =  150  of total steps  201  loss =  0.6683021187782288\n",
      "epoch =  67  step =  200  of total steps  201  loss =  0.2724890112876892\n",
      "epoch :  68  /  100  | TL :  0.31794304256118944  | VL :  0.23189896903932095\n",
      "saving model\n",
      "epoch =  68  step =  50  of total steps  201  loss =  0.3556596636772156\n",
      "epoch =  68  step =  100  of total steps  201  loss =  0.480337917804718\n",
      "epoch =  68  step =  150  of total steps  201  loss =  0.16038645803928375\n",
      "epoch =  68  step =  200  of total steps  201  loss =  0.21800869703292847\n",
      "epoch :  69  /  100  | TL :  0.3113496101169444  | VL :  0.24524967279285192\n",
      "epoch =  69  step =  50  of total steps  201  loss =  0.1882162094116211\n",
      "epoch =  69  step =  100  of total steps  201  loss =  0.18377360701560974\n",
      "epoch =  69  step =  150  of total steps  201  loss =  0.21767520904541016\n",
      "epoch =  69  step =  200  of total steps  201  loss =  0.31671178340911865\n",
      "epoch :  70  /  100  | TL :  0.31665259792437006  | VL :  0.25472658732905984\n",
      "epoch =  70  step =  50  of total steps  201  loss =  0.3373001515865326\n",
      "epoch =  70  step =  100  of total steps  201  loss =  0.7362411022186279\n",
      "epoch =  70  step =  150  of total steps  201  loss =  0.17569124698638916\n",
      "epoch =  70  step =  200  of total steps  201  loss =  0.4360350966453552\n",
      "epoch :  71  /  100  | TL :  0.32346629525595044  | VL :  0.24574070842936635\n",
      "epoch =  71  step =  50  of total steps  201  loss =  0.36686986684799194\n",
      "epoch =  71  step =  100  of total steps  201  loss =  0.16527342796325684\n",
      "epoch =  71  step =  150  of total steps  201  loss =  1.2192497253417969\n",
      "epoch =  71  step =  200  of total steps  201  loss =  0.2349722683429718\n",
      "epoch :  72  /  100  | TL :  0.3183494582401579  | VL :  0.24558366322889924\n",
      "epoch =  72  step =  50  of total steps  201  loss =  0.3223501145839691\n",
      "epoch =  72  step =  100  of total steps  201  loss =  0.22182802855968475\n",
      "epoch =  72  step =  150  of total steps  201  loss =  0.273536741733551\n",
      "epoch =  72  step =  200  of total steps  201  loss =  0.3395945429801941\n",
      "epoch :  73  /  100  | TL :  0.33300125309780465  | VL :  0.23200823506340384\n",
      "epoch =  73  step =  50  of total steps  201  loss =  0.3627837002277374\n",
      "epoch =  73  step =  100  of total steps  201  loss =  0.15646123886108398\n",
      "epoch =  73  step =  150  of total steps  201  loss =  0.16254113614559174\n",
      "epoch =  73  step =  200  of total steps  201  loss =  0.19798997044563293\n",
      "epoch :  74  /  100  | TL :  0.31670125114235714  | VL :  0.2569074295461178\n",
      "epoch =  74  step =  50  of total steps  201  loss =  0.34026119112968445\n",
      "epoch =  74  step =  100  of total steps  201  loss =  0.15267488360404968\n",
      "epoch =  74  step =  150  of total steps  201  loss =  0.11732587963342667\n",
      "epoch =  74  step =  200  of total steps  201  loss =  0.27421969175338745\n",
      "epoch :  75  /  100  | TL :  0.2977011043263312  | VL :  0.23997638607397676\n",
      "epoch =  75  step =  50  of total steps  201  loss =  0.19101856648921967\n",
      "epoch =  75  step =  100  of total steps  201  loss =  0.2767895758152008\n",
      "epoch =  75  step =  150  of total steps  201  loss =  0.1884261816740036\n",
      "epoch =  75  step =  200  of total steps  201  loss =  0.1732959896326065\n",
      "epoch :  76  /  100  | TL :  0.3171227295971035  | VL :  0.23299336805939674\n",
      "epoch =  76  step =  50  of total steps  201  loss =  0.3173459470272064\n",
      "epoch =  76  step =  100  of total steps  201  loss =  0.2004496306180954\n",
      "epoch =  76  step =  150  of total steps  201  loss =  0.22892290353775024\n",
      "epoch =  76  step =  200  of total steps  201  loss =  0.1292695701122284\n",
      "epoch :  77  /  100  | TL :  0.3167617641648843  | VL :  0.24462439585477114\n",
      "epoch =  77  step =  50  of total steps  201  loss =  0.33252567052841187\n",
      "epoch =  77  step =  100  of total steps  201  loss =  0.1923408806324005\n",
      "epoch =  77  step =  150  of total steps  201  loss =  0.21928946673870087\n",
      "epoch =  77  step =  200  of total steps  201  loss =  0.44104453921318054\n",
      "epoch :  78  /  100  | TL :  0.30604318955644444  | VL :  0.2535029733553529\n",
      "epoch =  78  step =  50  of total steps  201  loss =  0.6459758281707764\n",
      "epoch =  78  step =  100  of total steps  201  loss =  0.19725023210048676\n",
      "epoch =  78  step =  150  of total steps  201  loss =  0.17957842350006104\n",
      "epoch =  78  step =  200  of total steps  201  loss =  0.1677478849887848\n",
      "epoch :  79  /  100  | TL :  0.2998556426137834  | VL :  0.26208432111889124\n",
      "epoch =  79  step =  50  of total steps  201  loss =  0.3187202513217926\n",
      "epoch =  79  step =  100  of total steps  201  loss =  0.2203155755996704\n",
      "epoch =  79  step =  150  of total steps  201  loss =  0.253486692905426\n",
      "epoch =  79  step =  200  of total steps  201  loss =  0.16533389687538147\n",
      "epoch :  80  /  100  | TL :  0.3120168007873184  | VL :  0.2456403230316937\n",
      "epoch =  80  step =  50  of total steps  201  loss =  0.24230855703353882\n",
      "epoch =  80  step =  100  of total steps  201  loss =  0.23643341660499573\n",
      "epoch =  80  step =  150  of total steps  201  loss =  0.4528714418411255\n",
      "epoch =  80  step =  200  of total steps  201  loss =  0.6008489727973938\n",
      "epoch :  81  /  100  | TL :  0.2991255760118736  | VL :  0.24986435333266854\n",
      "epoch =  81  step =  50  of total steps  201  loss =  0.19289255142211914\n",
      "epoch =  81  step =  100  of total steps  201  loss =  0.15219111740589142\n",
      "epoch =  81  step =  150  of total steps  201  loss =  0.3784070611000061\n",
      "epoch =  81  step =  200  of total steps  201  loss =  0.33805304765701294\n",
      "epoch :  82  /  100  | TL :  0.2968135106548741  | VL :  0.23478377936407924\n",
      "epoch =  82  step =  50  of total steps  201  loss =  0.18427273631095886\n",
      "epoch =  82  step =  100  of total steps  201  loss =  0.14643675088882446\n",
      "epoch =  82  step =  150  of total steps  201  loss =  0.16075454652309418\n",
      "epoch =  82  step =  200  of total steps  201  loss =  0.2912862300872803\n",
      "epoch :  83  /  100  | TL :  0.29825962687013163  | VL :  0.2365310317836702\n",
      "epoch =  83  step =  50  of total steps  201  loss =  0.31983456015586853\n",
      "epoch =  83  step =  100  of total steps  201  loss =  0.1827252060174942\n",
      "epoch =  83  step =  150  of total steps  201  loss =  0.26709678769111633\n",
      "epoch =  83  step =  200  of total steps  201  loss =  0.183475062251091\n",
      "epoch :  84  /  100  | TL :  0.3083568730609334  | VL :  0.23240531515330076\n",
      "epoch =  84  step =  50  of total steps  201  loss =  0.22386965155601501\n",
      "epoch =  84  step =  100  of total steps  201  loss =  0.3271198570728302\n",
      "epoch =  84  step =  150  of total steps  201  loss =  0.1451701670885086\n",
      "epoch =  84  step =  200  of total steps  201  loss =  0.2647281885147095\n",
      "epoch :  85  /  100  | TL :  0.31783713771039573  | VL :  0.25778811844065785\n",
      "epoch =  85  step =  50  of total steps  201  loss =  0.43040239810943604\n",
      "epoch =  85  step =  100  of total steps  201  loss =  0.14497911930084229\n",
      "epoch =  85  step =  150  of total steps  201  loss =  0.38995593786239624\n",
      "epoch =  85  step =  200  of total steps  201  loss =  0.2572198808193207\n",
      "epoch :  86  /  100  | TL :  0.32173356513923673  | VL :  0.24101581843569875\n",
      "epoch =  86  step =  50  of total steps  201  loss =  0.23910066485404968\n",
      "epoch =  86  step =  100  of total steps  201  loss =  0.4084475040435791\n",
      "epoch =  86  step =  150  of total steps  201  loss =  0.13633255660533905\n",
      "epoch =  86  step =  200  of total steps  201  loss =  0.393359512090683\n",
      "epoch :  87  /  100  | TL :  0.2933935110100466  | VL :  0.2569144801236689\n",
      "epoch =  87  step =  50  of total steps  201  loss =  0.1512928605079651\n",
      "epoch =  87  step =  100  of total steps  201  loss =  0.3531481921672821\n",
      "epoch =  87  step =  150  of total steps  201  loss =  0.1535779982805252\n",
      "epoch =  87  step =  200  of total steps  201  loss =  0.3834499716758728\n",
      "epoch :  88  /  100  | TL :  0.28493628250574  | VL :  0.24541931226849556\n",
      "epoch =  88  step =  50  of total steps  201  loss =  0.1710035353899002\n",
      "epoch =  88  step =  100  of total steps  201  loss =  0.3824246823787689\n",
      "epoch =  88  step =  150  of total steps  201  loss =  0.22586923837661743\n",
      "epoch =  88  step =  200  of total steps  201  loss =  0.2371288537979126\n",
      "epoch :  89  /  100  | TL :  0.2882078862308863  | VL :  0.24643920874223113\n",
      "epoch =  89  step =  50  of total steps  201  loss =  0.3365078270435333\n",
      "epoch =  89  step =  100  of total steps  201  loss =  0.15239664912223816\n",
      "epoch =  89  step =  150  of total steps  201  loss =  0.16534703969955444\n",
      "epoch =  89  step =  200  of total steps  201  loss =  0.2667176425457001\n",
      "epoch :  90  /  100  | TL :  0.3122037124648616  | VL :  0.22880833083763719\n",
      "saving model\n",
      "epoch =  90  step =  50  of total steps  201  loss =  0.5422722697257996\n",
      "epoch =  90  step =  100  of total steps  201  loss =  0.44972139596939087\n",
      "epoch =  90  step =  150  of total steps  201  loss =  0.3448857069015503\n",
      "epoch =  90  step =  200  of total steps  201  loss =  0.32490628957748413\n",
      "epoch :  91  /  100  | TL :  0.3112139443853008  | VL :  0.23700041603296995\n",
      "epoch =  91  step =  50  of total steps  201  loss =  0.5013622641563416\n",
      "epoch =  91  step =  100  of total steps  201  loss =  1.101981282234192\n",
      "epoch =  91  step =  150  of total steps  201  loss =  0.22238820791244507\n",
      "epoch =  91  step =  200  of total steps  201  loss =  0.16782011091709137\n",
      "epoch :  92  /  100  | TL :  0.29970990695911853  | VL :  0.23591851443052292\n",
      "epoch =  92  step =  50  of total steps  201  loss =  0.45168420672416687\n",
      "epoch =  92  step =  100  of total steps  201  loss =  0.1245446652173996\n",
      "epoch =  92  step =  150  of total steps  201  loss =  0.22180671989917755\n",
      "epoch =  92  step =  200  of total steps  201  loss =  0.30285969376564026\n",
      "epoch :  93  /  100  | TL :  0.2913072292336184  | VL :  0.22513250587508082\n",
      "saving model\n",
      "epoch =  93  step =  50  of total steps  201  loss =  0.2359590232372284\n",
      "epoch =  93  step =  100  of total steps  201  loss =  0.1924617886543274\n",
      "epoch =  93  step =  150  of total steps  201  loss =  0.6886243224143982\n",
      "epoch =  93  step =  200  of total steps  201  loss =  0.3037150204181671\n",
      "epoch :  94  /  100  | TL :  0.30512696600968564  | VL :  0.2669175215996802\n",
      "epoch =  94  step =  50  of total steps  201  loss =  0.28284376859664917\n",
      "epoch =  94  step =  100  of total steps  201  loss =  0.3827824294567108\n",
      "epoch =  94  step =  150  of total steps  201  loss =  0.19702592492103577\n",
      "epoch =  94  step =  200  of total steps  201  loss =  0.44205033779144287\n",
      "epoch :  95  /  100  | TL :  0.30264226866153937  | VL :  0.24767168425023556\n",
      "epoch =  95  step =  50  of total steps  201  loss =  0.5773036479949951\n",
      "epoch =  95  step =  100  of total steps  201  loss =  0.22044265270233154\n",
      "epoch =  95  step =  150  of total steps  201  loss =  0.18735425174236298\n",
      "epoch =  95  step =  200  of total steps  201  loss =  0.13178251683712006\n",
      "epoch :  96  /  100  | TL :  0.2998593454411374  | VL :  0.24040822638198733\n",
      "epoch =  96  step =  50  of total steps  201  loss =  0.239935964345932\n",
      "epoch =  96  step =  100  of total steps  201  loss =  0.291557252407074\n",
      "epoch =  96  step =  150  of total steps  201  loss =  0.20458965003490448\n",
      "epoch =  96  step =  200  of total steps  201  loss =  0.33639538288116455\n",
      "epoch :  97  /  100  | TL :  0.3065089816922572  | VL :  0.2352840704843402\n",
      "epoch =  97  step =  50  of total steps  201  loss =  0.27683764696121216\n",
      "epoch =  97  step =  100  of total steps  201  loss =  0.4923291802406311\n",
      "epoch =  97  step =  150  of total steps  201  loss =  0.38004472851753235\n",
      "epoch =  97  step =  200  of total steps  201  loss =  0.3266274034976959\n",
      "epoch :  98  /  100  | TL :  0.3113200760154582  | VL :  0.23033263580873609\n",
      "epoch =  98  step =  50  of total steps  201  loss =  0.18629887700080872\n",
      "epoch =  98  step =  100  of total steps  201  loss =  0.2092566192150116\n",
      "epoch =  98  step =  150  of total steps  201  loss =  0.13844721019268036\n",
      "epoch =  98  step =  200  of total steps  201  loss =  0.1796165406703949\n",
      "epoch :  99  /  100  | TL :  0.2991179142202904  | VL :  0.24004241172224283\n",
      "epoch =  99  step =  50  of total steps  201  loss =  0.38601982593536377\n",
      "epoch =  99  step =  100  of total steps  201  loss =  0.36523857712745667\n",
      "epoch =  99  step =  150  of total steps  201  loss =  0.25564008951187134\n",
      "epoch =  99  step =  200  of total steps  201  loss =  0.1803300827741623\n",
      "epoch :  100  /  100  | TL :  0.2948301003421124  | VL :  0.24956152541562915\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Read the Comet API key from the environment instead of committing it to the notebook.\n",
    "# NOTE(review): the key previously hard-coded here was published with the notebook and should be revoked.\n",
    "experiment = Experiment(api_key=os.environ[\"COMET_API_KEY\"],\n",
    "                        project_name=\"kd0\", workspace=\"akshaykvnit\")\n",
    "experiment.log_parameters(hyper_params)\n",
    "filename = '../saved_models/stage' + str(hyper_params['stage']) + '/model' + str(hyper_params['repeated']) + '.pt'\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr = hyper_params[\"learning_rate\"])\n",
    "total_step = len(data.train_ds) // hyper_params[\"batch_size\"]\n",
    "train_loss_list = list()\n",
    "val_loss_list = list()\n",
    "# Best validation loss so far; +inf guarantees the first epoch is checkpointed.\n",
    "min_val = float('inf')\n",
    "\n",
    "def _to_device(images, labels):\n",
    "    \"\"\"Move a batch to the GPU when one is available; images are cast to float.\"\"\"\n",
    "    if torch.cuda.is_available():\n",
    "        return images.cuda().float(), labels.cuda()\n",
    "    return images.float(), labels\n",
    "\n",
    "for epoch in range(hyper_params[\"num_epochs\"]):\n",
    "    # ---- train: update `net` to minimise the MSE between the stage features captured by the `sf` / `sf2` hooks ----\n",
    "    trn = []\n",
    "    net.train()\n",
    "    for i, (images, labels) in enumerate(data.train_dl):\n",
    "        images, labels = _to_device(images, labels)\n",
    "\n",
    "        # The forward passes populate the `sf` / `sf2` feature hooks (set up in an earlier cell).\n",
    "        y_pred = net(images)\n",
    "        # `mdl` is not optimised by this loop, so its forward pass needs no autograd graph.\n",
    "        with torch.no_grad():\n",
    "            y_pred2 = mdl(images)\n",
    "\n",
    "        loss = F.mse_loss(sf2[hyper_params[\"stage\"]].features, sf[hyper_params[\"stage\"]].features)\n",
    "        trn.append(loss.item())\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 50 == 49:\n",
    "            print('epoch = ', epoch, ' step = ', i + 1, ' of total steps ', total_step, ' loss = ', loss.item())\n",
    "\n",
    "    train_loss = sum(trn) / len(trn)\n",
    "    train_loss_list.append(train_loss)\n",
    "\n",
    "    # ---- validate: same feature loss, no parameter updates ----\n",
    "    net.eval()\n",
    "    val = []\n",
    "    with torch.no_grad():\n",
    "        for images, labels in data.valid_dl:\n",
    "            images, labels = _to_device(images, labels)\n",
    "            y_pred = net(images)\n",
    "            y_pred2 = mdl(images)\n",
    "            loss = F.mse_loss(sf2[hyper_params[\"stage\"]].features, sf[hyper_params[\"stage\"]].features)\n",
    "            val.append(loss.item())\n",
    "\n",
    "    val_loss = sum(val) / len(val)\n",
    "    val_loss_list.append(val_loss)\n",
    "    print('epoch : ', epoch + 1, ' / ', hyper_params[\"num_epochs\"], ' | TL : ', train_loss, ' | VL : ', val_loss)\n",
    "    experiment.log_metric(\"train_loss\", train_loss)\n",
    "    experiment.log_metric(\"val_loss\", val_loss)\n",
    "\n",
    "    # Checkpoint `net` whenever the validation loss improves.\n",
    "    if val_loss < min_val:\n",
    "        print('saving model')\n",
    "        min_val = val_loss\n",
    "        torch.save(net.state_dict(), filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved stage-5 weights into `net`. The checkpoint tensors are mapped onto the CPU\n",
    "# first, then the whole network is moved back to the GPU.\n",
    "net = net.cpu()\n",
    "net.load_state_dict(torch.load('../saved_models/stage5/model0.pt', map_location = torch.device('cpu')))\n",
    "net = net.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAgAElEQVR4nO3dd3hUVf7H8fdJryQhCYQklNB7DUWKgCgCq2AH1F2xYW+7FnZ/9rLrrq66KIKuBdeOYEEFEaV3QieUACGQECC998z5/XEGSEhCEjJhmMn39Tw8ydy5uXPujH7m3O8991yltUYIIYTjc7F3A4QQQtiGBLoQQjgJCXQhhHASEuhCCOEkJNCFEMJJuNnrhUNCQnS7du3s9fJCCOGQtmzZkqa1Dq3uObsFert27YiJibHXywshhENSSh2p6TkpuQghhJOQQBdCCCchgS6EEE7CbjV0IYTzKS0tJSkpiaKiIns3xeF5eXkRGRmJu7t7nf9GAl0IYTNJSUn4+/vTrl07lFL2bo7D0lqTnp5OUlISUVFRdf47KbkIIWymqKiI4OBgCfMGUkoRHBxc7yMdCXQhhE1JmNvG+byPDhfo+0/k8vqS/WTkl9i7KUIIcVFxuECPT83jneUHOZkjJ12EEKIihwt0H09zHregpNzOLRFCXGyysrJ499136/13EyZMICsrq95/N23aNObPn1/vv2ssjhfoHq4AFEqgCyHOUlOgl5efOy8WLVpEYGBgYzXrgnG4YYve7ibQ80vK7NwSIcS5vPBjLHuSc2y6ze7hzXju6h41Pj9jxgwOHTpE3759cXd3x8/Pj1atWrF9+3b27NnDNddcQ2JiIkVFRTzyyCNMnz4dODO3VF5eHuPHj2f48OGsW7eOiIgIfvjhB7y9vWtt2++//87jjz9OWVkZAwcOZPbs2Xh6ejJjxgwWLlyIm5sbY8eO5fXXX+ebb77hhRdewNXVlYCAAFatWmWT96fWHrpS6iOlVIpSanct6w1USpUrpW6wSctqID10IURNXn31VTp06MD27dt57bXX2LRpE6+88gp79uwB4KOPPmLLli3ExMQwc+ZM0tPTq2zjwIEDPPDAA8TGxhIYGMiCBQtqfd2ioiKmTZvG119/za5duygrK2P27NlkZGTw3XffERsby86dO3n66acBePHFF1myZAk7duxg4cKFNtv/uvTQ5wLvAP+raQWllCvwT2CJbZpVM1+poQvhEM7Vk75QBg0aVOnCnJkzZ/Ldd98BkJiYyIEDBwgODq70N1FRUfTt2xeAAQMGkJCQUOvr7N+/n6ioKDp37gzAbbfdxqxZs3jwwQfx8vLirrvu4g9/+ANXXXUVAMOGDWPatGncdNNNXHfddbbYVaAOPXSt9Sogo5bVHgIWACm2aNS5eFt76AVSchFC1MLX1/f07ytWrOC3335j/fr17Nixg379+lV74Y6np+fp311dXSkrqz1rtNbVLndzc2PTpk1cf/31fP/994wbNw6AOXPm8PLLL5OYmEjfvn2rPVI4Hw2uoSulIoBrgcuAgbWsOx2YDtCmTZvzej0f91OBLj10IURl/v7+5ObmVvtcdnY2QUFB+Pj4sG/fPjZs2GCz1+3atSsJCQkcPHiQjh078umnnzJy5Ejy8vIoKChgwoQJDBkyhI4dOwJw6NAhBg8ezODBg/nxxx9JTEyscqRwPmxxUvQt4CmtdXltVzZprd8H3geIjo6u/iutFm6uLni4uUigCyGqCA4OZtiwYfTs2RNvb29atmx5+rlx48YxZ84cevfuTZcuXRgyZIjNXtfLy4uPP/6YG2+88fRJ0XvvvZeMjAwmTZpEUVERWmvefPNNAJ544gkOHDiA1poxY8bQp08fm7RD1XSoUGklpdoBP2mte1bz3GHgVJKHAAXAdK319+faZnR0tD7fOxb1ffFXJvYJ58VJVZojhLCjvXv30q1bN3s3w2lU
934qpbZoraOrW7/BPXSt9ekzDkqpuZjgP2eYN5SPu6v00IUQ4iy1BrpS6ktgFBCilEoCngPcAbTWcxq1dTXw8XSTYYtCiAvmgQceYO3atZWWPfLII9x+++12alH1ag10rfXUum5Maz2tQa2pIx8PV7mwSAhxwcyaNcveTagTh7v0H8zVolJyEUKIyhwy0H2l5CKEEFU4ZKB7S8lFCCGqcMhA93F3lR66EEKcxSED3dfTTWroQgib8PPzq/G5hIQEevZ0nOtdHDLQvT1cZS4XIYQ4i8PNhw6m5FJariktt+Du6pDfSUI4v8Uz4MQu224zrBeMf/Wcqzz11FO0bduW+++/H4Dnn38epRSrVq0iMzOT0tJSXn75ZSZNmlSvly4qKuK+++4jJiYGNzc33njjDUaPHk1sbCy33347JSUlWCwWFixYQHh4ODfddBNJSUmUl5fzzDPPMHny5PPe7bpyzECvMIVugLcEuhDijClTpvDoo4+eDvR58+bxyy+/8Nhjj9GsWTPS0tIYMmQIEydOpLb5pyo6NRZ9165d7Nu3j7FjxxIXF8ecOXN45JFHuOWWWygpKaG8vJxFixYRHh7Ozz//DJiJwS4Exwz0ClPoBni727k1Qohq1dKTbiz9+vUjJSWF5ORkUlNTCQoKolWrVjz22GOsWrUKFxcXjh07xsmTJwkLC6vzdtesWcNDDz0EmNkV27ZtS1xcHJdccgmvvPIKSUlJXHfddXTq1IlevXrx+OOP89RTT3HVVVcxYsSIxtrdShyye3sm0OXEqBCiqhtuuIH58+fz9ddfM2XKFD7//HNSU1PZsmUL27dvp2XLltXOhX4uNU1kePPNN7Nw4UK8vb258sorWbZsGZ07d2bLli306tWLv/71r7z44ou22K1aOWgP3TRbhi4KIaozZcoU7r77btLS0li5ciXz5s2jRYsWuLu7s3z5co4cOVLvbV566aV8/vnnXHbZZcTFxXH06FG6dOlCfHw87du35+GHHyY+Pp6dO3fStWtXmjdvzq233oqfnx9z5861/U5Ww0ED3Xqj6GIZ6SKEqKpHjx7k5uYSERFBq1atuOWWW7j66quJjo6mb9++dO3atd7bvP/++7n33nvp1asXbm5uzJ07F09PT77++ms+++wz3N3dCQsL49lnn2Xz5s088cQTuLi44O7uzuzZsxthL6uq03zojaEh86FvPZrJde+u4+PbBzK6Swsbt0wIcb5kPnTbqu986A5ZQ/eVkosQQlQhJRchRJO3a9cu/vjHP1Za5unpycaNG+3UovPjkIHubQ30wlLpoQtxsdFa12t898WgV69ebN++3d7NqOR8yuEOXXKRYYtCXFy8vLxIT08/rzASZ2itSU9Px8vLq15/55A9dC93F5SCAim5CHFRiYyMJCkpidTUVHs3xeF5eXkRGRlZr79xyEBXSsldi4S4CLm7uxMVFVX7iqJROGTJBczFRQVSQxdCiNMcONBdpeQihBAVOHagS8lFCCFOc+hAl2GLQghxhgMHuptcWCSEEBXUGuhKqY+UUilKqd01PH+LUmqn9d86pVQf2zezKm8puQghRCV16aHPBcad4/nDwEitdW/gJeB9G7SrVr5SchFCiEpqHYeutV6llGp3jufXVXi4AajfSPjz5O3hRn6xBLoQQpxi6xr6ncDimp5USk1XSsUopWIaeiWZj4crhSVSQxdCiFNsFuhKqdGYQH+qpnW01u9rraO11tGhoaENej1fD1cKSstlzgghhLCySaArpXoDHwCTtNbptthmbbw93NAaikotF+LlhBDiotfgQFdKtQG+Bf6otY5reJPq5syNoqXsIoQQUIeTokqpL4FRQIhSKgl4DnAH0FrPAZ4FgoF3rXMgl9V0eyRbOhPo5QQ39osJIYQDqMsol6m1PH8XcJfNWlRHPjInuhBCVOLAV4pKyUUIISpy+ECXG0ULIYThwIFuSi75EuhCCAE4cKB7S8lFCCEqcdhA9/WUkosQQlTk
sIHu4y4lFyGEqMhhA9379ElRKbkIIQQ4cKB7uLng7qqkhy6EEFYOG+gA3u6uUkMXQggrxwv0pC3w3b1QkIGPh5uMchFCCCvHC/SCNNjxJaQfxMfDVUouQghh5XiB3ry9+ZkRj4+nlFyEEOIUxwv0wDagXEygu0vJRQghTnG8QHfzhIBIyIjH28NVZlsUQggrxwt0MGWXjHh8PSXQhRDiFIcOdG93N6mhCyGEleMGemEmwS755EsNXQghAIcN9A4AhOvjUnIRQggrBw10M3SxZVkyJWUWysotdm6QEELYn2MGelA7QBFacgyAglLppQshhGMGursXNIugeUkSIHOiCyEEOGqgAzSPIqAwEUDq6EIIgUMHenv88k2g5xfLSBchhKg10JVSHymlUpRSu2t4XimlZiqlDiqldiql+tu+mdVo3h7P4nT8KCC3SAJdCCHq0kOfC4w7x/PjgU7Wf9OB2Q1vVh1YR7q0VSnEJmdfkJcUQoiLWa2BrrVeBWScY5VJwP+0sQEIVEq1slUDa2QN9P7+GWw5ktnoLyeEEBc7W9TQI4DECo+TrMuqUEpNV0rFKKViUlNTG/aqzaMAGByQTcyRTLTWDdueEEI4OFsEuqpmWbXpqrV+X2sdrbWODg0NbdireviCXxhdPVJJzS0mKbOwYdsTQggHZ4tATwJaV3gcCSTbYLu1a96eVpbjAGw9KmUXIUTTZotAXwj8yTraZQiQrbU+boPt1q55e3zyjuLr4Sp1dCFEk+dW2wpKqS+BUUCIUioJeA5wB9BazwEWAROAg0ABcHtjNbaK5lGo7ccZHOklgS6EaPJqDXSt9dRantfAAzZrUX1YR7qMCs3j+U2K/OIyfD1r3SUhhHBKjnulKJwe6dK/WTYWDTsSs+zcICGEsB/HDvSgdgB0cEtDKaTsIoRo0hw70L2DwCsA77xEOrfwl5EuQogmzbEDHUwvPTOB/m2D2Ho0C4tFLjASQjRNjh/ogW0h6wgD2gaRXVhKfFqevVskhBB24fiBHtQOMo8Q3SYAgPXx55p2RgghnJdzBHp5MW09cogM8mbNgQbOESOEEA7KOQIdUFlHGNEphHUH0+Wm0UKIJslpAp3MBEZ0CiW3uIwdSTIeXQjR9Dh+oAe0BhRkHmFoh2CUgtUH0uzdKiGEuOAcP9DdPCAgEjITCPTxoHdkoAS6EKJJcvxAh9Nj0QEu7RTC9sQscopK7dokIYS40Jwk0NueDvThHUMot2jWH0q3b5uEEOICc5JAbwd5J6C0kH5tgvD1cGW1DF8UQjQxzhHoge3Mz6yjeLi5MKR9sNTRhRBNjnMEeoWhiwAjOoVwJL2Ao+kFdmuSEEJcaM4Z6J3NDahXxqXYpz1CCGEHzhHoviHg7ns60NuH+NIu2IeleyXQhRBNh3MEulKVhi4qpbiie0vWH0ojV4YvCiGaCOcIdLAOXTxy+uEV3cMoLdesjJPRLkKIpsGJAr2d6aFrc4OL/m0CCfJxZ+mek3ZtlhBCXCjOFeil+ZBvhiu6ubpwWdeWLN+XQqnMviiEaAKcK9DhdB0d4IruLckpKmPTYbnphRDC+dUp0JVS45RS+5VSB5VSM6p5vo1SarlSaptSaqdSaoLtm1qLFt3MzwNLTi+6tHMInm4uUnYRQjQJtQa6UsoVmAWMB7oDU5VS3c9a7Wlgnta6HzAFeNfWDa1VYBvoPgk2vgeFmQD4eLgxvGMIS/ecRGu5ebQQwrnVpYc+CDiotY7XWpcAXwGTzlpHA82svwcAybZrYj2MfAqKc2DD7NOLrujekmNZhew9nmuXJgkhxIVSl0CPABIrPE6yLqvoeeBWpVQSsAh4yCatq6+WPaDb1bBhDhSauxaN6dYSgOX75SIjIYRzq0ugq2qWnV2/mArM1VpHAhOAT5VSVbatlJqulIpRSsWkpjbS+PCRT0FxNmycA0CovyftQ3zZkSi3pRNCOLe6BHoS0LrC40iqllTuBOYBaK3XA15AyNkb0lq/r7WO
1lpHh4aGnl+LaxPWC7peBRvehaJsAHpHBsh9RoUQTq8ugb4Z6KSUilJKeWBOei48a52jwBgApVQ3TKDb7xLNS58wYb5rPgB9WgdyMqeYE9lFdmuSEEI0tloDXWtdBjwILAH2YkazxCqlXlRKTbSu9hfgbqXUDuBLYJq257CSVn3AtwUkbgJMoANsl7KLEMKJudVlJa31IszJzorLnq3w+x5gmG2b1gBKQeRAOBYDQPdWzXBzUexIymJczzA7N04IIRqH81wperbIaEg/CAUZeLm70rWVPzulji6EcGJOHOgDzc9jWwDoExnIzsRsLBa5wEgI4ZycN9DD+4FygaTNgKmj5xaXEZ+Wb+eGCSFE43DeQPf0gxY9Tgd6X+uJURmPLoRwVs4b6GDq6ElbwGKhQ6gfPh6uMh5dCOG0nDzQB5qrRtMP4Oqi6BURwI6kbHu3SgghGoXzBzpUKrvsTc6huKzcjo0SQojG4dyBHtwRvAIqnRgtKbewT2ZeFEI4IecOdBcXiIiGJHOBUe/IAAC2Hc20Z6uEEKJROHeggym7pOyB4lwiAr3p1MKPOSvjycwvsXfLhBDCpppGoGsLxK9EaQtvTu5Len4xTy7YKXcxEkI4FecP9Ij+oFzh61vg5Rb0/GYE7/ZPZumek3y24Yi9WyeEEDbj/IHu0xzuXApXvQlDH4ayIi4vWMyoLqG89PNe9h7PsXcLhRDCJpw/0AEiB0D0HXD5c9B9EuroOl6/rhsB3u48/OU2ikplGKMQwvE1jUCvKGoklBYQkrmTf9/YhwMpebzy8157t0oIIRqs6QV6u+Fm0q7DK7m0cyh3DY/i0w1H+G3PSXu3TAghGqTpBbp3oJmJMX4FAE+M60L3Vs14csFOUnLkFnVCCMfV9AIdTNnl2BYozsXTzZWZU/tRUFLGjG93yVBGIYTDapqB3n4UWMrgyDoAOrbw4/GxXVi2L4UlsSfs2jQhhDhfTTPQWw8GN6/TZReAaUPb0a1VM174cQ95xWX2a5sQQpynphno7l4m1ONXnl7k5urCK9f25EROEW8tjbNj44QQ4vw0zUAHU3ZJiYW8lNOL+rcJYsrANny8LkEuOBJCOJwmHOgjzc/DqyotfmpcFwK93Xnxxz12aJQQQpy/phvorfqaudLPCvRAHw/uGdme9fHpxCbL3Y2EEI6jToGulBqnlNqvlDqolJpRwzo3KaX2KKVilVJf2LaZjcDFFSIGwLGtVZ6aHN0Gb3dX5q5NuPDtEkKI81RroCulXIFZwHigOzBVKdX9rHU6AX8FhmmtewCPNkJbbS+8v5krvaSg0uIAH3eu6x/BDzuSSc8rtlPjhBCifurSQx8EHNRax2utS4CvgElnrXM3MEtrnQmgtU7BEUT0B10OJ3ZWeer2Ye0oKbPw5aajdmiYEELUX10CPQJIrPA4ybqsos5AZ6XUWqXUBqXUuOo2pJSarpSKUUrFpKamnl+LbSm8v/lZTdmlYwt/RnQK4dMNRygtt1zghgkhRP3VJdBVNcvOvj7eDegEjAKmAh8opQKr/JHW72uto7XW0aGhofVtq+01awX+4ZBcNdAB7hgWxcmcYhbtOn6BGyaEEPVXl0BPAlpXeBwJJFezzg9a61Kt9WFgPybgL34R/avtoQOM7BxKVIgvL/20h38s2suOxCyZ60UIcdGqS6BvBjoppaKUUh7AFGDhWet8D4wGUEqFYEow8bZsaKMJ7wcZh6Awq8pTLi6KN27qQ4/wAD5cc5hJs9Zy45z1lEkJRghxEao10LXWZcCDwBJgLzBPax2rlHpRKTXRutoSIF0ptQdYDjyhtU5vrEbbVIS1jp68rdqn+7UJ4pM7BhHz9OU8Oa4LMUcy+WpzYrXrCiGEPdVpHLrWepHWurPWuoPW+hXrsme11gutv2ut9Z+11t211r201l81ZqNtKryf+Vmxjr7sZVjyf5VWC/Tx4L6RHRgU1Zw3l8aRW1R6ARsphBC1a7pXip7iHQTN25+pox/b
Cqteg/XvwN6fKq2qlOLpP3QjPb+Ed1ccskNjhRCiZhLoYIYvJm8DrWHps+ATAi16wM9/qVJb7x0ZyLX9IvhwzWGSMgtq2KAQQlx4Euhg6ug5x2DrJ5CwGkbNgGtmQX4KLH2myupPXNkFBbz44x4y8ksqPZeSU8S2o5kXqOFCCHGGm70bcFE4dYHRoichuCMMmAau7jD0IVj7H+h1I0Rdemb1QG8eHN2Rfy+N4/d9v3FJ+2B6hDdj7aE0dh8z0+7+/PBweoQH2GFnhBBNlfTQAVr1BuUC5cVwxYsmzAFG/dXU1396DMornwR98LKO/PTQcO65tD1JmQX8d3U8Xm6uPHp5J1xdlFyMJIS44KSHDuDha2ZedPOCLhPOLHf3hiv/Dl9Oge2fm567lVKKnhEB9IwI4Ikru1BcZsHL3RWAmIRMFu06weNju6BUdRfaCiGE7UkP/ZRbv4Wb58HZAdx5nLld3YpXobSw2j9VSp0Oc4BxPcM4nJbP/pO5jdliIYSoRAL9FK9m4OFTdblScPnzkHscNr1fp01d2SMMpWDxrhM2baIQQpyLBHpdtB0KncbC6jeqnSLgbKH+ngxq15zFu6WOLoS4cCTQ6+qyZ6AoC9bNrNPq43uGEXcyj4MpeQCUlFn4ZfcJsgvkClMhROOQQK+rVr3N8MV170Dq/lpXH9ezFQC/7D5OWl4xt364kXs/28Lwfy7j9SX7yTxr/LoQQjSUBHp9jH3FjIj57p4qwxjPFhbgRf82gcyLSWLi22vYkZjFM1d1Z0TnEGatOMiwfy5jXoxM8iWEsB0J9PrwbwlXv2WmCVj971pXn9CrFUczClBKseC+odw5PIp3bxnAr49eSr82gTw5fyfP/bBb7ogkhLAJGYdeX90nQe/JZgKvTmPPTL9bjckDW1NcZmHywNaE+HmeXt6ppT+f3D6If/6yj/+uPszeE7m8fE1POrf0vxB7IIRwUsped+CJjo7WMTExdnntBivMgncvAa8AuH991bHr9fDD9mM8tWAnRaUWurdqxrX9Ipg6uA1+nvJdK4SoSim1RWsdXd1zUnI5H96BZgKv1L2QsqdBm5rUN4LVT17Gc1d3x91V8cqivdz76RYsFrnVnRCifiTQz1fHMebnoeUN3lSovye3D4vihweH8/dre7HmYBqzV8p860KI+pFAP18BkRDcCeLPCvSc4/D7S1CUc16bnTqoNVf3CeeNpXFsTsiwQUOFEE2FBHpDdLgMEtZCWfGZZWvfgtWvwxeToaT+N8BQSvH3a3sSGeTNw19uqzLf+rm8sTSOh77chr3Oiwgh7EsCvSE6jIayQkjcaB6XlcDOeRDSGY6uh69vrRz2deTv5c6sm/uTnlfC0Fd/5+7/xTBvc+I5w33u2sPM/P0AP+5IZmVc6vnukRDCgUmgN0S74eDidqaOHvcLFGbAlf+AiTPh0O+w4C4oL6v3pntGBPDNvZcwObo1e5JzeHLBTga98ht3fbKZn3YmU1hSfnrd3/ee5MWf9nB5t5aEB3jx9rKD0ksXogmSsXEN4ekPkQOtdfTnYPsX4N/K9NxdXKE4D5b81VxZet37Zlk99GkdSJ/WgTw/URObnMOPO5L5fvsxftubgoerC/3aBDKgbRBz1yXQIzyAmVP7Mn9LEs/+EMv6+HSGdghpnP0WQlyUpIfeUO1HQ/J2M7/LgV/NRUengvuS+83Uu7vnw/f3gaX8XFuq0ambafx1QjfWzRjDF3cNZtqwduSXlDF75SECvd354LZofDzcuCm6NaH+nryz7KDNdlEI4Rjq1ENXSo0D/gO4Ah9orV+tYb0bgG+AgVprB71qqJ46jIYVf4eFD4Euh743V35++GMmyJe9BMoVJs0Cl/P/HnV1UQztGMLQjqb3nVVQgouLopmXuW2el7sr91zanpd/3suWIxkMaNv8vF9LCOFYag10pZQrMAu4AkgCNiulFmqt95y1nj/wMLCxMRp60QrvD54B5sRoRDSEdqm6zqWPg6UM
VvwDOl8JPa6x2csH+nhUWXbz4DbMWn6QZ76P5fJuLfD3cqd9qC+XdW1xzlviaa3llnlCOLC6dBUHAQe11vFa6xLgK2BSNeu9BPwLKLJh+y5+rm4QNcL8fnbvvKJLnwC/lqb80sh8PNz424RuJGUW8Pbyg7yyaC93fhLDcwtjKa/mCtSi0nL+9cs++r20lJ93yk05hHBUdSm5RAAV53lNAgZXXEEp1Q9orbX+SSn1uA3b5xh6XgfHtkLP62tex8UVul8DWz+B4lxzQrUR3RjdmhujW2OxaPJLynhn2UHeWxVPclYhM6f2w8fDDYtFs+ZgGs/8sJsj6QW0CvDika+24e6qGNsjrFHbJ4SwvboEenXH4Ke7eUopF+BNYFqtG1JqOjAdoE2bNnVroSPoef25w/z0etfBpvdg/2LofVPjtwtwcVH4e7nz1wndiAzy5rmFsVzxxipcXRQnsosoKbcQFeLLF3cPpldEALd+uIkHvtjK+3+KZnSXFhekjUII26h1tkWl1CXA81rrK62P/wqgtf6H9XEAcAjIs/5JGJABTDzXiVGHnm3xfFks8FZPCOsNN391ZnlBBngHNWjWxrpatu8kH69NIMjHg/BAb9qH+DKxbzhe7mZkTnZBKTd/sIEDKXl8eFs0IzqF1rit3KJSfD3ccHGRursQF8q5ZlusS6C7AXHAGOAYsBm4WWsdW8P6K4DHaxvl0iQDHWDJ/8HG9+CJg2bWxoQ18L9rYMwzMOwRe7cOgIz8Em7+7wbi0/J5/48DGFVNTz0jv4TL31hJr4gAPrgtGnfXM6dj0vOKcXdzOT3ypjp5xWWk5RbTLsS3UfZBCGfVoOlztdZlwIPAEmAvME9rHauUelEpNdG2TW0CelwHllLY97OZyOub283jdW9D6cVxPrm5rwdf3j2EjqF+TP/fFpbvS6myzszfD5BZUMLKuFSeWrDz9JWpy/enMOr1FYx/azVH06ufy6bcovnjhxsZ/5/VpOReHPsshDOo04BorfUirXVnrXUHrfUr1mXPaq0XVrPuqCYzBv18RPSHwLawax7Mvx1K8mD8vyA/FXZ+VfvfXyBBvh58cfdguoT5M/3TGJbtO3n6uYS0fD7bcIQpA9vw2OWd+XbrMV5bsp/3VpbYMzQAABrjSURBVB7ijrmbiQj0Jr+kjBvfW8fBlLwq2/5gdTzbjmZRWFrO7BUyTbAQtiJXil5oSkGPayF+hZnAa+LbMGi6qauve8fU2S8SgT4efHbXYLqGNePez7ay7lAaAK8t2Y+HmwuPXdGJh8d0ZOqg1ry74hD/WLyPCb1a8e39Q/lq+hDKLTD5vfXsPpZ9epsHU3L599I4ruzRkhsHRPL5xqOcyD7TS9das/d4DqsPpLJwRzKLdh2nrI73XC0sKee1JfvYk3x+UxcL4ejkFnT2cGI3vDcCBt4NE/5llu2aDwvuhKlfQZfx9m3fWTLyS5j83nqOZRXy5JVdeP7HPTwyphOPXdEZgLJyCy//vJeIQG/uGhF1+uKkQ6l53PLfjaTmFXNN3wjuG9WBv3yzg6Pp+fz62EiKSssZ/foKpg5qw0vX9KS4rJxHv9rO4t0nKr3+Je2DmTm1H6H+nlXadkpuUSl3zo1hU0IGzX09mH/vJbQP9Wu8N0UIO2nQSdHG0qQDHSDrKAS0PjOypbwUZvYzy+5YfGY9reHwKtjxJfiGwCUPgv+FHyOeklPEje+t50h6ASF+nqx8YhS+dbjvaVpeMXNWHOKzjUcoKjU97ben9uPqPuEA/O27XXwTk8jiR0bwwo97WH0gjT9f0Zkh7YMJ8nFn29Esnl24m2Ze7rw9tR+D2wdXeY2sghJu+3gzu49l89S4Lry3Mh4vd1cW3DeUsAAv274R9fTD9mP8+9c4/n1THwa2k2kYRMNJoDuK9e+a2RmH/xncvU19fd/PkH7QTC9Qkmem6x1wm1mnWauq29C60YY/JmUW8OAX27hrRBRX9Q6v19+m5Bbx4erDuLkqHh/b
5XQvPjmrkFGvrcDFBUrKLPzrhj7cMCCy0t/uPZ7D/Z9v5XBaPt7uroT6exLi54Gnmyturooj6QWcyC7inZv7MbZHGLuSspny/noigrx59PLO5BaVkl1YytGMAg6czONwWj7je4bx/MQelaY6+G3PSRbuSCazoITMghKCfDyYNrQdo7u0qPfQzKLScl78aQ9fbDwKwGVdW/DRtIE1rr/vRA4Rgd74n2NkUE1O5hRxNKNAvjCaCAl0R1GcC28PgDzrCUjlCpHREH0HdJ8EuSdgzRtmml7vILh1AbTqY9YtzDJzr+ckwx2/gFez6l8j7ldzcdMNH9e8zgX20k97+HT9EWZO7ce4ntUffeQWlTIvJonjWYWk5hWTnldCSZmFUosFV6V45PJOlcbMrzuYxrSPN1NSof7u7+VGpxZ++Hq6sfpAGk//oRt3jWgPwMq4VO6Yu5nmvmZ8fpCPO3EncknOLqJDqC9TB7WhS5g/7YJ9CQ/0xvUcAZ+YUcA9n25hz/Ec7h3ZAVcXeHfFIVY8Poq2wWeGaVosmt/2nuT9VfHEHMmkT2QAX02/BG+Puk+znFtUyqR31nI4PZ+PbhvI6K6NczHYjsQs2gb7VDt3kLiwJNAdSXmp+efmWfP86an74bPrTYhP/QKaRcCXUyDjMGgLdLsabpxbtaeeEQ/vjYTiHJjwOgy6u9F3py4sFk1WYSnNfW0bFieyi8gsKKGZtzv+Xm74e7qhlEJrzX2fbWXp3pN8escgmnm7M/m99bQJ9mXePUNO95JLyy0s2nWc91fFE1vhRKufpxtX9W7FjdGR9G8TVKmXvzkhg3s+3UJZuYW3pvTlsq4tOZlTxLBXlzFtaDuevqo7ANmFpdz83w3EJpue+ZU9wvh43WHG9Qhj1s39cXFRlFs0C7Yk0SrQq9oLvCruR+sgb9LzS/j+gWF0sOG5g9JyC/9cvI8P1hymdXNvPp42kI4tGnfaCnFuEujOKPuYCfWMQ+DuY8J78meQFAO/PQfjX4PB08+sX1oIH14BWYng1wJc3OG+tRfk6tSLUV5xGdfOWktaXjHuri64uii+u39YtTV3rTUnc4pJSM/nSHo+Gw9nsHjXCQpLy2kb7MOwjiEMaR9MTmEpL/wYS2SQDx/eFl3ppOxDX25j5f4UNvxtDN7urjzwxVZ+jT3Jq9f35pq+4bi5uvDhmsO89NMe7hnZnnE9wnj6+93EJufg7e7Kzw8Pr3KS972VZmTR03/oxrieYUx8Zy2BPu58/8Cwc17UVZ3U3GJ+3JHM8v0ptAv2ZXinEDq28GPGgp1sTsjk+v6RrIxLoaTMwpw/Drhobp6SVVDCrmPZ7Duey7ieYbRu7mPvJjU6CXRnVZBh7ltakAFTPofgDmbY41dT4eDvcOcSiBhg1l34sJkYbOrXpqTz48Nwx6/QxjrPmqXcTAHcenC976zkqA6n5TPxnTUAzL93KF3C6t7zzCsuY/Gu4yzadZzNCZnkFZvbDA7rGMy7Nw8gwKdyoMYkZHDDnPW8cm1PtIanv9/NjPFduXdkh9PraK159odYPt1wBICWzTx5eEwn/vXLftoG+7DgvqGnr8hdfSCV2z7axLiepkevlGJDfDq3frCRwe2b89Kknqe/ALILS/nvqnh+23uSoR1CmNg3nD6RARxJL2BlXCq/70th7cE0yi2aDqG+HM8uosB6i0Nvd1devb4Xk/pGkJhRwB1zN3M4LZ/Xb+zDNf0iKrX9nWUHWX0wDXdXhZuLC13C/LlvZAeCbHzkVVpu4butx/jv6ngOVLjOoWdEM767f1ilq5adkQS6M9Pa/Kt404yCDHjvUsg9bmZ1dPeBnGPmZhuXP29ujfdGN+gyAa57z/zN0udg7Vvm+eGPXfj9sJODKXm4KBo0xLGs3MKe4zkkZxUxpluLagNFa81Vb68hq6CU1LxiLmkfzMfTBlY52VpWbuGZH2Jp5uXGQ2M64efpxuJdx7nv8608MLoDf76iC3NWHuKNpXG0D/Hl
uweG4VdhtNHXm4/yzPexlJRbuKxrC3pFBDB3XQLZhaX0axNI7LEcSsot+Hu5kVtkvoTaBvtwVe9WXNM3gk4t/Skps7D1aCY7ErMY061FpRJLdmEp93waw8bDGbxxUx+u7ReJ1ppXft7LB2sO0ysiAA83F0rKLMQmZ+Pn6cbDYzrxp0va4eF25n3JLy7j9V/3s/ZgGk//oTuXdq55zqCK7823W4/x9vIDJGYU0isigPG9wugTGciJ7CL+8s0OHh/bmQcv63TO7VgsmicX7GTJ7hP4errh6+lKt1bNeGpc1wb38LMLS9l0OIMN8enEJmdzVe9wbh7UxqbzHUmgN0VpB2DbZ1CSb8otzcJh5FNm/naAnx+Hrf+Dv+yDxE3w5WTwCjDTD9y3DkI6Vr/dshJwkxNj52NeTCJPzt9Jy2aeLHp4BMF+NY+rP9sT3+xg/tYkekcEsCMpm4l9wnn52p7VllZSc4v5bMMRPttwhPT8EkZ1CeXxsV3oGRFAdmEpS3afYFNCBr0jA7i0U2i959MpLCnnjrmb2Xg4nX/f1Ie4k3nMXnGIaUPb8dzV3U+fU4g7mcvLP+9lVVwqYc28GN8rjAm9WpFXVMbT3+8mObuQlv5enMgp4pbBbfjbhG7VDoXVWvPL7hO89ut+4lPz6RURwKOXd6pyw5YHv9jKktgT/PjQcLqG1XzC/51lB3j91zj+0KsVPh6u5BWXsTIuFYvWPDi6I3df2h5PtzNHqVprvolJ4pstiYT6e9I6yIfWzX2ICvElKsQXfy83lu45yQ/bk1ljPdLxdHMhPNCbw2n5DIpqzqvX9bLZdRES6KKqk7EweygMvhd2fAWBrWHy5+aCpxY9YNrPptefecTcaSllrxk7X5gBY56DEX+29x44nKLScp77IZbJg1rTv01Qvf42r7iMq2auJiW3mBcm9uCGAZG13l2qqLSclJxi2gTbvq58KtTXx6cDcMvgNrx8Tc9q27RifwqfbTjKqgOplJSZUUcdW/jxz+t70SM8gH//up8P1hwmxM+THuHNiAj0JtTfk7yiMjILStl7PIc9x3Po1MKPx6/swtjuLat9nYz8Esa+uZKwAK8aSy8r9qdw+9zNTOoTzpuT+1YaPvvST3tYvPsErZt7c+ewKG6Mbk251vzt2138tPM4nVr4Ua41SRmFlUZPnRIR6M1VfVoxuksL+rYOxNPNhW9iknj55z0UlVkY3SWUAW2DGNA2iB7hAadnOK0vCXRRvQ/Hmrq5ZzOYvsLU4Ld9Dj/cb0bBuLjCr8+YddsMMRc9pcWZE6/n6sVXJysRNs6Bk7vBww88fCGsFwy5v8nU7BsqI7+EsnILLZrZ92KpUwpKyvjLvB20CvDm6T90q7WskFdcxrJ9KeQWlXLDgMhKveBNhzP4aM1hEjMLOJZVSFZBKd7urgT5uBPazItbBrfh+v6R5xwuCvDL7uPc+9lWRnUJ5cHRHRnQ9swopKPpBVz9zhrCA7359r6h1Q4PXRmXyn9+i2Pr0Sz8vdzw83QjJbeYP1/R2ToEVWGxaE7kFJGQls/h9HxSc4sZ3jGE/m2Cqn0PUnKKePO3A6w7lMYR64R1t13Slhcm9az1Pa6OBLqoXux3MP8OM8Sxu/WuglrDp9dC/HLzOGokTHoHAq03JMk9CbMGmrlnbvvxzCiZpC1wfJup3xekmzH0vsHgEwwJa2H3ArNeeF9T1inOgexE6DwOrv+gfndwSoqBr242XwhdrzLnAnyamxO7Spkhn7ZksUBxthn7Ly6I0nLLeZ/cnLPyELNXHCK7sJQ+rQNpHeRN3Mlc4lPz8fFw5aeHRtR61LL1aCYfrjnMscxCnrmqOwPa2uazT80tZuvRTCICvekZEXBe25BAFzUrzDLzsleUeQS+vRt6TzYXNZ19eBvzMfz0KEx6F3rdAL+9ABtmnXnes5kJ19J889jDD/rfBkPuM6WdUzb9FxY/BS26wZQvTJ1fWyAvxUxeFr/cfIFMeA1amvHb5oTv
SCgvAXcvyEyo3DblAuNehcH32OLdMb69x1yxO31F/Y5KhN0UlJSxYEsSn6w/QlFpOV3D/Onc0p+r+4TTrdXFcUHd+ZJAF7ZlscDH4yFtPzSLhJO7zIyRw/9seuSnTpqWFkJ+mvnCqKkHfvB3+Gaa6bGfza8lWMxIDP70A7TsaXrmB5bCHUvMVMQnY+HQMigrNjX/Q8tNGenu5RB2foe0lRz4DT6/HlDQqjfcubT2I4DMBFOqGvMshJx7xIUQ9SWBLmwvZS/MGWFGxlzzLnS+8vy3lXYQ9nxnflcupkffbji06G6ubv3kajNap9cNsPkD0wMfcl/128pPh3eHgG8oTF9et/JL5hFzNBC/0lx0NeY58PAxr/nuEHD1hNF/M/PXX/IgXPlKzdvSGv43CQ6vNCWhu36v2gat4dDvsPY/5rWLss08PWG9oO8tZj/dvCF5GyRtNss7jK7beykaR1aiuV/BJQ+aeZbsSAJdNI6TseAXZmrljSkzwYR61lFTM5/82bmvcI1bAl/cZG7pd8WLZllJvrk69tTRQ2EW7PwatnwCKda7KfqFmYuuWvaEKZ+ZL491b8O0RdBuGPz8F7PslgXQ6fLqX/vUSeWeN8Du+eak77h/nHn+yDr4/SU4ug4C2kDbS8yXopuXObo4uct8gaBNWQkAZfZj6EMVZucsMzdFcXE15yu8moFr/Sf2uqAy4s0EdNlJMHGm+fJsiJJ8c3L9Qvjmdoj9FjqMMeVBd/udmJZAF44vO8nU7oc+WLeTkz8+YsK63XBIPwS5yYAyZZxmrSBlH5QVQng/c66gw2UQ0hkO/mbmpUeZydL63WrCB0wJ6b+XmQnQRj4F0bdX7q3lpcA7A82RxbSf4ZcZZiK0m+dBQKQ513BgifniuPRxc17h7DH9x3fAznnmSKXNEDP52q9PmxPY0XeaI5Ptn5svjvwKtwYMbGvKQf4t6/e+HttqTlh3vBzaj6p9KoiiHHMEEtoFrv5P3Y6AkmLM0cjeH82XjnIxYX7LfLOd83FsC3w8wcxbNGnWmXacjIV5t5kRW+NeheZR57f9ijKPwMy+5r+VY1ug01jTqWjIyfcGzIoqgS6anuI800svLTRBHdLRnKjNTjRfDkFRMGCaGXVztvRDZkqFwiy4f13lL5D0Q+bLImG1+XIYch+06gtB7eD3F8zJ03vXQmhnM5rngzFm0rTSAtOLHv5nc77Box5jwy0Ws+21b5nHytWUuDpcZh6XFcGyV0zg3Lawbj31Y1th5T8h7pczy8L7wbBHzfulXMx2gqLOXIWstSk77fnBnLxuN8JMOeFVzWiNU2WlNW+Z98orwHwhDb7HfCF+MRnKi83tF/NSIGmTeb/Hvlz9Z1JRYZa5XqIo2/w71Y6jG8yoLXcf835bysz73X2SKWkV55ijr/oeGSyeAZv/C4/shAO/mgEBna6E8f+s/xdGwhpY8aopqw2YVr+/tZJAF6K+ystMD76mk7mHV5tATFhdefnop2HkE2cep+43Xw6dr7SeNG7AnOU755kvpD5TzYigSs99A9/edabEU15mSj7HtppZNU+dnC0pgKXPmNKRd5CpCQ+4Hfb9aMI383Dl7bYbAde+BwERsPF9WPyEmR6iWQR8fx+EdoWpX54Z1gpwYpcpTyVuBP9wc1TV/zbwrHClZOYR+PxGc2IdzBFGWZEJ6Kv/A32mVP8eaA3z/gj7F8Ptv5jJ6X54wLxOTpIJ7Klfmd7vkr+ZI5uKfELMtNO1fWmcUpgJb/QwRwKnpsnY/AEseuLMl1rfW0yv/Vylx4S15gK9Ux2BK16seR9rIYEuRGPJPmZCMOOw6QVG32m/qREWz4CNs02ox/1iatbKxfwbeDd0nWCCNi0OhjwAo2ZUnhO/vAwSVpmyii43+7biVdNTH/YILP87dBwDU760jihaBl//0fSGO4wxAZW8DTbMNiObxjwLfW6u+f0ozjUljBbdTa85L9WMeDqyxpS6mkWYcxpFOWZoa2S0uX3j
0mdMT37oQ2Y7h5aZMkv7UXDtnMp19cRN5tyLZzMT8j89Znr4N39tzoukHzLDZ1NizTqe/ubLb9B08/vqN8zR0b1rzMnpU7KTzF3Etn1mHTqrTHmsw2hz7UabIaYcl7gZlr9shuH6tjDzJJ1dqqsnCXQhmoLyUvhkojnhGtYLRs6A1oNMEG/9xPQo/cPh2tkm/Ooi/ZA5p5C8zZzEvWdl5aOMzAQzJ9COr00PGUwpYcxz53c0Ul4KS5+FDe8CygyD9fA1oYw1qzqPO9MLP6WucwxlH4NPrzHbazvUnIh2cTNhXFp45oI3v5Yw+v/Me9eiG/zp++q3Z7FA8lbzpXJouSkdWcrMie3gjuaLwifEGuR31K/UVgMJdCGaiqJs04ttO7Ry4J3YbcoUA++sf9CWlZgvhKhLaz6JabHA0fWmx1+xJ3u+inJMLfzUZHJF2aZ8lLrfHAmcfTFcfeSnwxc3mlCPvsP8q3if3qQYc0I7abN5fOu35sikLopz4ch6M2w1eZs52TxoeuVyUwNJoAshREUWM997jfMIaW1G/6QdMKWpi+hGMOcK9Npv2242MA74D+AKfKC1fvWs5/8M3AWUAanAHVrrIw1qtRBCNJbaJoRTyoxEcTC1zn6jlHIFZgHjge7AVKVU97NW2wZEa617A/OBf9m6oUIIIc6tLtOZDQIOaq3jtdYlwFfApIoraK2Xa60LrA83AJG2baYQQoja1CXQI4DECo+TrMtqciewuLonlFLTlVIxSqmY1NTUurdSCCFEreoS6NWdDaj2TKpS6lYgGnituue11u9rraO11tGhobXfQ1AIIUTd1eWkaBJQYRJrIoHks1dSSl0O/B8wUmtdbJvmCSGEqKu69NA3A52UUlFKKQ9gCrCw4gpKqX7Ae8BErXVKNdsQQgjRyGoNdK11GfAgsATYC8zTWscqpV5USk20rvYa4Ad8o5TarpRaWMPmhBBCNJI6jUPXWi8CFp217NkKv9cwObQQQogLxW5XiiqlUoHzvfgoBEizYXMcRVPc76a4z9A097sp7jPUf7/baq2rHVVit0BvCKVUTE2XvjqzprjfTXGfoWnud1PcZ7DtftflpKgQQggHIIEuhBBOwlED/X17N8BOmuJ+N8V9hqa5301xn8GG++2QNXQhhBBVOWoPXQghxFkk0IUQwkk4XKArpcYppfYrpQ4qpWbYuz2NQSnVWim1XCm1VykVq5R6xLq8uVJqqVLqgPVnkL3b2hiUUq5KqW1KqZ+sj6OUUhut+/21dQoKp6GUClRKzVdK7bN+5pc0hc9aKfWY9b/v3UqpL5VSXs74WSulPlJKpSildldYVu3nq4yZ1nzbqZTqX5/XcqhAr+PNNpxBGfAXrXU3YAjwgHU/ZwC/a607Ab9bHzujRzDTTJzyT+BN635nYqZodib/AX7RWncF+mD23ak/a6VUBPAw5sY4PTF3Q5uCc37Wc4FxZy2r6fMdD3Sy/psOzK7PCzlUoFOHm204A631ca31VuvvuZj/wSMw+/qJdbVPgGvs08LGo5SKBP4AfGB9rIDLMHfCAifbb6VUM+BS4EMArXWJ1jqLJvBZY6Ye8VZKuQE+wHGc8LPWWq8CMs5aXNPnOwn4nzY2AIFKqVZ1fS1HC/T63mzD4Sml2gH9gI1AS631cTChD7SwX8sazVvAk4DF+jgYyLJOEgfO95m3x9yH92NrmekDpZQvTv5Za62PAa8DRzFBng1swbk/64pq+nwblHGOFuh1vtmGM1BK+QELgEe11jn2bk9jU0pdBaRorbdUXFzNqs70mbsB/YHZWut+QD5OVl6pjrVmPAmIAsIBX0y54WzO9FnXRYP+e3e0QK/TzTacgVLKHRPmn2utv7UuPnnq8Mv609nmnh8GTFRKJWDKaZdheuyB1sNycL7PPAlI0lpvtD6ejwl4Z/+sLwcOa61TtdalwLfAUJz7s66ops+3QRnnaIFe6802nIG1bvwhsFdr/UaFpxYCt1l/vw344UK3
rTFprf+qtY7UWrfDfLbLtNa3AMuBG6yrOdV+a61PAIlKqS7WRWOAPTj5Z40ptQxRSvlY/3s/td9O+1mfpabPdyHwJ+tolyFA9qnSTJ1orR3qHzABiAMOAf9n7/Y00j4Oxxxm7QS2W/9NwNSTfwcOWH82t3dbG/E9GAX8ZP29PbAJOAh8A3jau3023te+QIz18/4eCGoKnzXwArAP2A18Cng642cNfIk5T1CK6YHfWdPniym5zLLm2y7MKKA6v5Zc+i+EEE7C0UouQgghaiCBLoQQTkICXQghnIQEuhBCOAkJdCGEcBIS6EII4SQk0IUQwkn8P8AjZFzGzdXJAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One point per recorded epoch. Each curve uses its own length, so a run whose validation\n",
    "# list is shorter (or empty) still plots instead of raising a length-mismatch error.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(range(len(train_loss_list)), train_loss_list, label = 'train_loss')\n",
    "ax.plot(range(len(val_loss_list)), val_loss_list, label = 'val_loss')\n",
    "ax.set(xlabel = 'epoch', ylabel = 'loss', title = 'Training / validation loss')\n",
    "ax.legend()\n",
    "fig.savefig('../figures/stage5/train_loss.jpg')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training only the classifier head after stage-wise training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model on GPU\n",
      "0.0.weight torch.Size([64, 3, 7, 7])\n",
      "False\n",
      "0.2.weight torch.Size([64])\n",
      "False\n",
      "0.2.bias torch.Size([64])\n",
      "False\n",
      "2.0.0.weight torch.Size([64, 64, 3, 3])\n",
      "False\n",
      "2.0.2.weight torch.Size([64])\n",
      "False\n",
      "2.0.2.bias torch.Size([64])\n",
      "False\n",
      "2.1.conv1.0.weight torch.Size([64, 64, 3, 3])\n",
      "False\n",
      "2.1.conv1.2.weight torch.Size([64])\n",
      "False\n",
      "2.1.conv1.2.bias torch.Size([64])\n",
      "False\n",
      "3.0.0.weight torch.Size([128, 64, 3, 3])\n",
      "False\n",
      "3.0.2.weight torch.Size([128])\n",
      "False\n",
      "3.0.2.bias torch.Size([128])\n",
      "False\n",
      "3.1.conv1.0.weight torch.Size([128, 128, 3, 3])\n",
      "False\n",
      "3.1.conv1.2.weight torch.Size([128])\n",
      "False\n",
      "3.1.conv1.2.bias torch.Size([128])\n",
      "False\n",
      "4.0.0.weight torch.Size([256, 128, 3, 3])\n",
      "False\n",
      "4.0.2.weight torch.Size([256])\n",
      "False\n",
      "4.0.2.bias torch.Size([256])\n",
      "False\n",
      "4.1.conv1.0.weight torch.Size([256, 256, 3, 3])\n",
      "False\n",
      "4.1.conv1.2.weight torch.Size([256])\n",
      "False\n",
      "4.1.conv1.2.bias torch.Size([256])\n",
      "False\n",
      "5.0.0.weight torch.Size([512, 256, 3, 3])\n",
      "False\n",
      "5.0.2.weight torch.Size([512])\n",
      "False\n",
      "5.0.2.bias torch.Size([512])\n",
      "False\n",
      "5.1.conv1.0.weight torch.Size([512, 512, 3, 3])\n",
      "False\n",
      "5.1.conv1.2.weight torch.Size([512])\n",
      "False\n",
      "5.1.conv1.2.bias torch.Size([512])\n",
      "False\n",
      "8.weight torch.Size([256, 1024])\n",
      "True\n",
      "8.bias torch.Size([256])\n",
      "True\n",
      "9.weight torch.Size([10, 256])\n",
      "True\n",
      "9.bias torch.Size([10])\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "hyper_params = {\n",
    "    \"stage\": 5,\n",
    "    \"num_classes\": 10,\n",
    "    \"batch_size\": 64,\n",
    "    \"num_epochs\": 100,\n",
    "    \"learning_rate\": 1e-4\n",
    "}\n",
    "\n",
    "class Flatten(nn.Module) :\n",
    "    \"\"\"Reshape (batch, ...) to (batch, -1), keeping only the batch dimension.\"\"\"\n",
    "    def forward(self, input):\n",
    "        return input.view(input.size(0), -1)\n",
    "\n",
    "def conv2(ni, nf) : \n",
    "    \"\"\"fastai conv_layer from ni to nf channels with stride 2 (spatial downsampling).\"\"\"\n",
    "    return conv_layer(ni, nf, stride = 2)\n",
    "\n",
    "class ResBlock(nn.Module):\n",
    "    \"\"\"Residual block computing x + conv_layer(x) at a constant nf channels.\"\"\"\n",
    "    def __init__(self, nf):\n",
    "        super().__init__()\n",
    "        self.conv1 = conv_layer(nf,nf)\n",
    "        \n",
    "    def forward(self, x): \n",
    "        return (x + self.conv1(x))\n",
    "\n",
    "def conv_and_res(ni, nf): \n",
    "    \"\"\"Stride-2 conv (ni -> nf channels) followed by a ResBlock.\"\"\"\n",
    "    return nn.Sequential(conv2(ni, nf), ResBlock(nf))\n",
    "\n",
    "def conv_(nf) : \n",
    "    \"\"\"conv_layer at nf channels (default stride) followed by a ResBlock.\"\"\"\n",
    "    return nn.Sequential(conv_layer(nf, nf), ResBlock(nf))\n",
    "    \n",
    "net = nn.Sequential(\n",
    "    conv_layer(3, 64, ks = 7, stride = 2, padding = 3),\n",
    "    nn.MaxPool2d(3, 2, padding = 1),\n",
    "    conv_(64),\n",
    "    conv_and_res(64, 128),\n",
    "    conv_and_res(128, 256),\n",
    "    conv_and_res(256, 512),\n",
    "    AdaptiveConcatPool2d(),\n",
    "    Flatten(),\n",
    "    nn.Linear(2 * 512, 256),\n",
    "    nn.Linear(256, hyper_params[\"num_classes\"])\n",
    ")\n",
    "\n",
    "# The trailing N_HEAD_LAYERS modules of `net` (the two nn.Linear layers) form the\n",
    "# classifier head; they are the only layers left trainable below.\n",
    "N_HEAD_LAYERS = 2\n",
    "head_indices = {str(i) for i in range(len(net) - N_HEAD_LAYERS, len(net))}\n",
    "\n",
    "net.cpu()\n",
    "net.load_state_dict(torch.load('../saved_models/stage5/model0.pt', map_location = 'cpu'))\n",
    "\n",
    "if torch.cuda.is_available() : \n",
    "    net = net.cuda()\n",
    "    print('Model on GPU')\n",
    "    \n",
    "# Freeze every parameter outside the head. The top-level module index is the text\n",
    "# before the first '.' in the parameter name; comparing it in full (instead of\n",
    "# name[0]) stays correct once the Sequential has 10 or more modules.\n",
    "for name, param in net.named_parameters() : \n",
    "    print(name, param.shape)\n",
    "    param.requires_grad = name.split('.')[0] in head_indices\n",
    "    print(param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _get_accuracy(dataloader, Net):\n",
    "    \"\"\"Return the top-1 accuracy (a fraction in [0, 1]) of `Net` over `dataloader`.\n",
    "\n",
    "    `dataloader` yields (images, labels) batches of class indices. Each batch is moved to\n",
    "    the device holding `Net`'s parameters. `Net` is left in eval mode afterwards.\n",
    "    Raises ValueError if the dataloader yields no samples.\n",
    "    \"\"\"\n",
    "    device = next(Net.parameters()).device\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    Net.eval()\n",
    "    # evaluation only: skip building the autograd graph\n",
    "    with torch.no_grad():\n",
    "        for images, labels in dataloader:\n",
    "            images = images.to(device).float()\n",
    "            labels = labels.to(device)\n",
    "            # argmax of the raw logits equals argmax of log_softmax (monotonic),\n",
    "            # so no normalisation is needed to pick the predicted class\n",
    "            pred_ind = Net(images).argmax(dim = 1)\n",
    "            correct += (pred_ind == labels).sum().item()\n",
    "            total += labels.size(0)\n",
    "\n",
    "    if total == 0:\n",
    "        raise ValueError('dataloader yielded no samples')\n",
    "    return correct / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "COMET INFO: ----------------------------\n",
      "COMET INFO: Comet.ml Experiment Summary:\n",
      "COMET INFO:   Data:\n",
      "COMET INFO:     url: https://www.comet.ml/akshaykvnit/kd0/68e29e7c7396443f9d5bc8c1d8fb34f7\n",
      "COMET INFO:   Metrics [count] (min, max):\n",
      "COMET INFO:     loss [104]                   : (0.09282064437866211, 2.5855929851531982)\n",
      "COMET INFO:     sys.gpu.0.free_memory [5]    : (10388766720.0, 10388766720.0)\n",
      "COMET INFO:     sys.gpu.0.gpu_utilization [5]: (0.0, 39.0)\n",
      "COMET INFO:     sys.gpu.0.total_memory       : (11721506816.0, 11721506816.0)\n",
      "COMET INFO:     sys.gpu.0.used_memory [5]    : (1332740096.0, 1332740096.0)\n",
      "COMET INFO:     sys.gpu.1.free_memory [5]    : (6221987840.0, 6221987840.0)\n",
      "COMET INFO:     sys.gpu.1.gpu_utilization [5]: (0.0, 0.0)\n",
      "COMET INFO:     sys.gpu.1.total_memory       : (6233391104.0, 6233391104.0)\n",
      "COMET INFO:     sys.gpu.1.used_memory [5]    : (11403264.0, 11403264.0)\n",
      "COMET INFO:     train_loss [5]               : (0.2542623481643734, 0.6748504299874329)\n",
      "COMET INFO:     val_acc [5]                  : (0.922, 0.934)\n",
      "COMET INFO:     val_loss [5]                 : (0.20247302344068885, 0.27922892197966576)\n",
      "COMET INFO: ----------------------------\n",
      "COMET INFO: Experiment is live on comet.ml https://www.comet.ml/akshaykvnit/kd0/38af5c0d61474e2d9273ada6ea156fe5\n",
      "\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch =  0  step =  50  of total steps  201  loss =  0.8784895539283752\n",
      "epoch =  0  step =  100  of total steps  201  loss =  0.4596188962459564\n",
      "epoch =  0  step =  150  of total steps  201  loss =  0.3605707585811615\n",
      "epoch =  0  step =  200  of total steps  201  loss =  0.28698939085006714\n",
      "epoch :  1  /  100  | TL :  0.6981583358488272  | VL :  0.2788702854886651  | VA :  94.0\n",
      "saving model\n",
      "epoch =  1  step =  50  of total steps  201  loss =  0.43684956431388855\n",
      "epoch =  1  step =  100  of total steps  201  loss =  0.257024884223938\n",
      "epoch =  1  step =  150  of total steps  201  loss =  0.261644184589386\n",
      "epoch =  1  step =  200  of total steps  201  loss =  0.2497667670249939\n",
      "epoch :  2  /  100  | TL :  0.3284727025091352  | VL :  0.22625178191810846  | VA :  93.4\n",
      "epoch =  2  step =  50  of total steps  201  loss =  0.31416481733322144\n",
      "epoch =  2  step =  100  of total steps  201  loss =  0.34493571519851685\n",
      "epoch =  2  step =  150  of total steps  201  loss =  0.43016624450683594\n",
      "epoch =  2  step =  200  of total steps  201  loss =  0.5088800191879272\n",
      "epoch :  3  /  100  | TL :  0.2902423296475885  | VL :  0.21517107030376792  | VA :  92.80000000000001\n",
      "epoch =  3  step =  50  of total steps  201  loss =  0.23638996481895447\n",
      "epoch =  3  step =  100  of total steps  201  loss =  0.26187020540237427\n",
      "epoch =  3  step =  150  of total steps  201  loss =  0.2791038751602173\n",
      "epoch =  3  step =  200  of total steps  201  loss =  0.264411598443985\n",
      "epoch :  4  /  100  | TL :  0.2746666660208014  | VL :  0.2047434956766665  | VA :  92.80000000000001\n",
      "epoch =  4  step =  50  of total steps  201  loss =  0.25847554206848145\n",
      "epoch =  4  step =  100  of total steps  201  loss =  0.3105309307575226\n",
      "epoch =  4  step =  150  of total steps  201  loss =  0.27651649713516235\n",
      "epoch =  4  step =  200  of total steps  201  loss =  0.261848509311676\n",
      "epoch :  5  /  100  | TL :  0.2633412720922807  | VL :  0.207779704593122  | VA :  93.4\n",
      "epoch =  5  step =  50  of total steps  201  loss =  0.1278153657913208\n",
      "epoch =  5  step =  100  of total steps  201  loss =  0.18454846739768982\n",
      "epoch =  5  step =  150  of total steps  201  loss =  0.2844318151473999\n",
      "epoch =  5  step =  200  of total steps  201  loss =  0.26928406953811646\n",
      "epoch :  6  /  100  | TL :  0.26328115917705186  | VL :  0.20662269508466125  | VA :  93.0\n",
      "epoch =  6  step =  50  of total steps  201  loss =  0.30374234914779663\n",
      "epoch =  6  step =  100  of total steps  201  loss =  0.22687304019927979\n",
      "epoch =  6  step =  150  of total steps  201  loss =  0.08077244460582733\n",
      "epoch =  6  step =  200  of total steps  201  loss =  0.13239970803260803\n",
      "epoch :  7  /  100  | TL :  0.24792141174499074  | VL :  0.20776139805093408  | VA :  93.60000000000001\n",
      "epoch =  7  step =  50  of total steps  201  loss =  0.18327292799949646\n",
      "epoch =  7  step =  100  of total steps  201  loss =  0.312863290309906\n",
      "epoch =  7  step =  150  of total steps  201  loss =  0.32132449746131897\n",
      "epoch =  7  step =  200  of total steps  201  loss =  0.22595562040805817\n",
      "epoch :  8  /  100  | TL :  0.2514310070494218  | VL :  0.2195680094882846  | VA :  92.60000000000001\n",
      "epoch =  8  step =  50  of total steps  201  loss =  0.21055161952972412\n",
      "epoch =  8  step =  100  of total steps  201  loss =  0.25211572647094727\n",
      "epoch =  8  step =  150  of total steps  201  loss =  0.19057099521160126\n",
      "epoch =  8  step =  200  of total steps  201  loss =  0.3789454698562622\n",
      "epoch :  9  /  100  | TL :  0.24296303695781313  | VL :  0.22576642222702503  | VA :  92.2\n",
      "epoch =  9  step =  50  of total steps  201  loss =  0.29526883363723755\n",
      "epoch =  9  step =  100  of total steps  201  loss =  0.3362762928009033\n",
      "epoch =  9  step =  150  of total steps  201  loss =  0.202084481716156\n",
      "epoch =  9  step =  200  of total steps  201  loss =  0.2655912935733795\n",
      "epoch :  10  /  100  | TL :  0.24155521841339805  | VL :  0.20563008030876517  | VA :  93.60000000000001\n",
      "epoch =  10  step =  50  of total steps  201  loss =  0.2488257735967636\n",
      "epoch =  10  step =  100  of total steps  201  loss =  0.372101366519928\n",
      "epoch =  10  step =  150  of total steps  201  loss =  0.14308716356754303\n",
      "epoch =  10  step =  200  of total steps  201  loss =  0.14739346504211426\n",
      "epoch :  11  /  100  | TL :  0.24629901071537785  | VL :  0.20945659140124917  | VA :  92.4\n",
      "epoch =  11  step =  50  of total steps  201  loss =  0.1840231865644455\n",
      "epoch =  11  step =  100  of total steps  201  loss =  0.40374600887298584\n",
      "epoch =  11  step =  150  of total steps  201  loss =  0.19593767821788788\n",
      "epoch =  11  step =  200  of total steps  201  loss =  0.18194571137428284\n",
      "epoch :  12  /  100  | TL :  0.24336236707903258  | VL :  0.2202791371382773  | VA :  92.0\n",
      "epoch =  12  step =  50  of total steps  201  loss =  0.24920272827148438\n",
      "epoch =  12  step =  100  of total steps  201  loss =  0.4710988998413086\n",
      "epoch =  12  step =  150  of total steps  201  loss =  0.3104923963546753\n",
      "epoch =  12  step =  200  of total steps  201  loss =  0.10895299166440964\n",
      "epoch :  13  /  100  | TL :  0.24499180468160714  | VL :  0.2121834997087717  | VA :  93.0\n",
      "epoch =  13  step =  50  of total steps  201  loss =  0.25743669271469116\n",
      "epoch =  13  step =  100  of total steps  201  loss =  0.4799787402153015\n",
      "epoch =  13  step =  150  of total steps  201  loss =  0.226236030459404\n",
      "epoch =  13  step =  200  of total steps  201  loss =  0.25183963775634766\n",
      "epoch :  14  /  100  | TL :  0.237427540864814  | VL :  0.21776134055107832  | VA :  92.60000000000001\n",
      "epoch =  14  step =  50  of total steps  201  loss =  0.13919349014759064\n",
      "epoch =  14  step =  100  of total steps  201  loss =  0.17916135489940643\n",
      "epoch =  14  step =  150  of total steps  201  loss =  0.2095920294523239\n",
      "epoch =  14  step =  200  of total steps  201  loss =  0.1525198519229889\n",
      "epoch :  15  /  100  | TL :  0.24328703388794146  | VL :  0.2209763708524406  | VA :  92.4\n",
      "epoch =  15  step =  50  of total steps  201  loss =  0.256217360496521\n",
      "epoch =  15  step =  100  of total steps  201  loss =  0.4744461476802826\n",
      "epoch =  15  step =  150  of total steps  201  loss =  0.3425982594490051\n",
      "epoch =  15  step =  200  of total steps  201  loss =  0.19082111120224\n",
      "epoch :  16  /  100  | TL :  0.2318665914756445  | VL :  0.21217101672664285  | VA :  93.2\n",
      "epoch =  16  step =  50  of total steps  201  loss =  0.2300536036491394\n",
      "epoch =  16  step =  100  of total steps  201  loss =  0.20806503295898438\n",
      "epoch =  16  step =  150  of total steps  201  loss =  0.2432381808757782\n",
      "epoch =  16  step =  200  of total steps  201  loss =  0.10463928431272507\n",
      "epoch :  17  /  100  | TL :  0.23323839067004212  | VL :  0.21110533736646175  | VA :  93.60000000000001\n",
      "epoch =  17  step =  50  of total steps  201  loss =  0.1391129046678543\n",
      "epoch =  17  step =  100  of total steps  201  loss =  0.1727377474308014\n",
      "epoch =  17  step =  150  of total steps  201  loss =  0.13050580024719238\n",
      "epoch =  17  step =  200  of total steps  201  loss =  0.15745554864406586\n",
      "epoch :  18  /  100  | TL :  0.23132630926904393  | VL :  0.22160714957863092  | VA :  92.60000000000001\n",
      "epoch =  18  step =  50  of total steps  201  loss =  0.3065853416919708\n",
      "epoch =  18  step =  100  of total steps  201  loss =  0.19926974177360535\n",
      "epoch =  18  step =  150  of total steps  201  loss =  0.16891729831695557\n",
      "epoch =  18  step =  200  of total steps  201  loss =  0.20695003867149353\n",
      "epoch :  19  /  100  | TL :  0.22946956911267927  | VL :  0.21192442066967487  | VA :  93.0\n",
      "epoch =  19  step =  50  of total steps  201  loss =  0.09637708216905594\n",
      "epoch =  19  step =  100  of total steps  201  loss =  0.281436562538147\n",
      "epoch =  19  step =  150  of total steps  201  loss =  0.321521520614624\n",
      "epoch =  19  step =  200  of total steps  201  loss =  0.22102399170398712\n",
      "epoch :  20  /  100  | TL :  0.23575785522585485  | VL :  0.2135665100067854  | VA :  92.60000000000001\n",
      "epoch =  20  step =  50  of total steps  201  loss =  0.20361152291297913\n",
      "epoch =  20  step =  100  of total steps  201  loss =  0.37990137934684753\n",
      "epoch =  20  step =  150  of total steps  201  loss =  0.13407979905605316\n",
      "epoch =  20  step =  200  of total steps  201  loss =  0.1882648915052414\n",
      "epoch :  21  /  100  | TL :  0.2319928347434274  | VL :  0.1973338397219777  | VA :  93.60000000000001\n",
      "epoch =  21  step =  50  of total steps  201  loss =  0.21543817222118378\n",
      "epoch =  21  step =  100  of total steps  201  loss =  0.14826373755931854\n",
      "epoch =  21  step =  150  of total steps  201  loss =  0.12414486706256866\n",
      "epoch =  21  step =  200  of total steps  201  loss =  0.223720520734787\n",
      "epoch :  22  /  100  | TL :  0.2287390042485586  | VL :  0.21034689154475927  | VA :  92.60000000000001\n",
      "epoch =  22  step =  50  of total steps  201  loss =  0.16999390721321106\n",
      "epoch =  22  step =  100  of total steps  201  loss =  0.5257681608200073\n",
      "epoch =  22  step =  150  of total steps  201  loss =  0.12145557999610901\n",
      "epoch =  22  step =  200  of total steps  201  loss =  0.2538580000400543\n",
      "epoch :  23  /  100  | TL :  0.2297842168615232  | VL :  0.21566509967669845  | VA :  92.80000000000001\n",
      "epoch =  23  step =  50  of total steps  201  loss =  0.21279509365558624\n",
      "epoch =  23  step =  100  of total steps  201  loss =  0.1486455500125885\n",
      "epoch =  23  step =  150  of total steps  201  loss =  0.2632969617843628\n",
      "epoch =  23  step =  200  of total steps  201  loss =  0.16457611322402954\n",
      "epoch :  24  /  100  | TL :  0.21836890515625773  | VL :  0.2090360908769071  | VA :  93.2\n",
      "epoch =  24  step =  50  of total steps  201  loss =  0.11507637053728104\n",
      "epoch =  24  step =  100  of total steps  201  loss =  0.21350052952766418\n",
      "epoch =  24  step =  150  of total steps  201  loss =  0.19461077451705933\n",
      "epoch =  24  step =  200  of total steps  201  loss =  0.0924970805644989\n",
      "epoch :  25  /  100  | TL :  0.23034875096743973  | VL :  0.22059228597208858  | VA :  92.80000000000001\n",
      "epoch =  25  step =  50  of total steps  201  loss =  0.2834708094596863\n",
      "epoch =  25  step =  100  of total steps  201  loss =  0.3684925138950348\n",
      "epoch =  25  step =  150  of total steps  201  loss =  0.25094810128211975\n",
      "epoch =  25  step =  200  of total steps  201  loss =  0.03063058853149414\n",
      "epoch :  26  /  100  | TL :  0.23034042375746058  | VL :  0.20344295306131244  | VA :  93.60000000000001\n",
      "epoch =  26  step =  50  of total steps  201  loss =  0.21462103724479675\n",
      "epoch =  26  step =  100  of total steps  201  loss =  0.2937115430831909\n",
      "epoch =  26  step =  150  of total steps  201  loss =  0.23049980401992798\n",
      "epoch =  26  step =  200  of total steps  201  loss =  0.1635710895061493\n",
      "epoch :  27  /  100  | TL :  0.22604353143949413  | VL :  0.21102545876055956  | VA :  93.0\n",
      "epoch =  27  step =  50  of total steps  201  loss =  0.19217555224895477\n",
      "epoch =  27  step =  100  of total steps  201  loss =  0.2887077331542969\n",
      "epoch =  27  step =  150  of total steps  201  loss =  0.11672937124967575\n",
      "epoch =  27  step =  200  of total steps  201  loss =  0.1992422491312027\n",
      "epoch :  28  /  100  | TL :  0.22197365047252593  | VL :  0.2104673981666565  | VA :  92.80000000000001\n",
      "epoch =  28  step =  50  of total steps  201  loss =  0.31323811411857605\n",
      "epoch =  28  step =  100  of total steps  201  loss =  0.28009775280952454\n",
      "epoch =  28  step =  150  of total steps  201  loss =  0.22451332211494446\n",
      "epoch =  28  step =  200  of total steps  201  loss =  0.37403005361557007\n",
      "epoch :  29  /  100  | TL :  0.22964832032868518  | VL :  0.22578905010595918  | VA :  92.80000000000001\n",
      "epoch =  29  step =  50  of total steps  201  loss =  0.11209139972925186\n",
      "epoch =  29  step =  100  of total steps  201  loss =  0.31813567876815796\n",
      "epoch =  29  step =  150  of total steps  201  loss =  0.2965307831764221\n",
      "epoch =  29  step =  200  of total steps  201  loss =  0.27872875332832336\n",
      "epoch :  30  /  100  | TL :  0.22967274550033445  | VL :  0.20579060819000006  | VA :  93.2\n",
      "epoch =  30  step =  50  of total steps  201  loss =  0.35499805212020874\n",
      "epoch =  30  step =  100  of total steps  201  loss =  0.0613543838262558\n",
      "epoch =  30  step =  150  of total steps  201  loss =  0.43352049589157104\n",
      "epoch =  30  step =  200  of total steps  201  loss =  0.39686667919158936\n",
      "epoch :  31  /  100  | TL :  0.225666430347891  | VL :  0.2129749502055347  | VA :  92.60000000000001\n",
      "epoch =  31  step =  50  of total steps  201  loss =  0.14546364545822144\n",
      "epoch =  31  step =  100  of total steps  201  loss =  0.2166927009820938\n",
      "epoch =  31  step =  150  of total steps  201  loss =  0.2029963731765747\n",
      "epoch =  31  step =  200  of total steps  201  loss =  0.1869591474533081\n",
      "epoch :  32  /  100  | TL :  0.22809847485070206  | VL :  0.21089945826679468  | VA :  93.0\n",
      "epoch =  32  step =  50  of total steps  201  loss =  0.1965678334236145\n",
      "epoch =  32  step =  100  of total steps  201  loss =  0.09901664406061172\n",
      "epoch =  32  step =  150  of total steps  201  loss =  0.21999090909957886\n",
      "epoch =  32  step =  200  of total steps  201  loss =  0.2442198246717453\n",
      "epoch :  33  /  100  | TL :  0.22197863491094527  | VL :  0.23099641920998693  | VA :  92.4\n",
      "epoch =  33  step =  50  of total steps  201  loss =  0.3653484284877777\n",
      "epoch =  33  step =  100  of total steps  201  loss =  0.11656630039215088\n",
      "epoch =  33  step =  150  of total steps  201  loss =  0.34430477023124695\n",
      "epoch =  33  step =  200  of total steps  201  loss =  0.3477349281311035\n",
      "epoch :  34  /  100  | TL :  0.21995753474274085  | VL :  0.22129679331555963  | VA :  93.0\n",
      "epoch =  34  step =  50  of total steps  201  loss =  0.1861451119184494\n",
      "epoch =  34  step =  100  of total steps  201  loss =  0.2736620008945465\n",
      "epoch =  34  step =  150  of total steps  201  loss =  0.20288968086242676\n",
      "epoch =  34  step =  200  of total steps  201  loss =  0.23123054206371307\n",
      "epoch :  35  /  100  | TL :  0.2158615339677132  | VL :  0.20560796977952123  | VA :  93.8\n",
      "epoch =  35  step =  50  of total steps  201  loss =  0.344775915145874\n",
      "epoch =  35  step =  100  of total steps  201  loss =  0.35674190521240234\n",
      "epoch =  35  step =  150  of total steps  201  loss =  0.10270797461271286\n",
      "epoch =  35  step =  200  of total steps  201  loss =  0.1683659553527832\n",
      "epoch :  36  /  100  | TL :  0.23095136784498965  | VL :  0.1976895914413035  | VA :  93.60000000000001\n",
      "epoch =  36  step =  50  of total steps  201  loss =  0.26144957542419434\n",
      "epoch =  36  step =  100  of total steps  201  loss =  0.17073455452919006\n",
      "epoch =  36  step =  150  of total steps  201  loss =  0.24032054841518402\n",
      "epoch =  36  step =  200  of total steps  201  loss =  0.256711483001709\n",
      "epoch :  37  /  100  | TL :  0.21890872589020588  | VL :  0.22587262652814388  | VA :  92.0\n",
      "epoch =  37  step =  50  of total steps  201  loss =  0.15837545692920685\n",
      "epoch =  37  step =  100  of total steps  201  loss =  0.24520274996757507\n",
      "epoch =  37  step =  150  of total steps  201  loss =  0.30873435735702515\n",
      "epoch =  37  step =  200  of total steps  201  loss =  0.1842978298664093\n",
      "epoch :  38  /  100  | TL :  0.22020536244137964  | VL :  0.20285974629223347  | VA :  93.60000000000001\n",
      "epoch =  38  step =  50  of total steps  201  loss =  0.3237845301628113\n",
      "epoch =  38  step =  100  of total steps  201  loss =  0.3796781301498413\n",
      "epoch =  38  step =  150  of total steps  201  loss =  0.3336086869239807\n",
      "epoch =  38  step =  200  of total steps  201  loss =  0.2904944121837616\n",
      "epoch :  39  /  100  | TL :  0.2288036356999803  | VL :  0.21355656068772078  | VA :  93.4\n",
      "epoch =  39  step =  50  of total steps  201  loss =  0.25732648372650146\n",
      "epoch =  39  step =  100  of total steps  201  loss =  0.1593049168586731\n",
      "epoch =  39  step =  150  of total steps  201  loss =  0.08520694822072983\n",
      "epoch =  39  step =  200  of total steps  201  loss =  0.26996099948883057\n",
      "epoch :  40  /  100  | TL :  0.22804281508448113  | VL :  0.20556521648541093  | VA :  93.0\n",
      "epoch =  40  step =  50  of total steps  201  loss =  0.22465980052947998\n",
      "epoch =  40  step =  100  of total steps  201  loss =  0.12247709184885025\n",
      "epoch =  40  step =  150  of total steps  201  loss =  0.2194310575723648\n",
      "epoch =  40  step =  200  of total steps  201  loss =  0.15610189735889435\n",
      "epoch :  41  /  100  | TL :  0.22599742822905086  | VL :  0.20361698511987925  | VA :  93.2\n",
      "epoch =  41  step =  50  of total steps  201  loss =  0.18349871039390564\n",
      "epoch =  41  step =  100  of total steps  201  loss =  0.36401087045669556\n",
      "epoch =  41  step =  150  of total steps  201  loss =  0.16551190614700317\n",
      "epoch =  41  step =  200  of total steps  201  loss =  0.4256124198436737\n",
      "epoch :  42  /  100  | TL :  0.2139500911481938  | VL :  0.20956390630453825  | VA :  93.0\n",
      "epoch =  42  step =  50  of total steps  201  loss =  0.1806209683418274\n",
      "epoch =  42  step =  100  of total steps  201  loss =  0.10414179414510727\n",
      "epoch =  42  step =  150  of total steps  201  loss =  0.3842158913612366\n",
      "epoch =  42  step =  200  of total steps  201  loss =  0.16441592574119568\n",
      "epoch :  43  /  100  | TL :  0.2172781045164042  | VL :  0.20362974982708693  | VA :  93.8\n",
      "epoch =  43  step =  50  of total steps  201  loss =  0.2632856070995331\n",
      "epoch =  43  step =  100  of total steps  201  loss =  0.35145777463912964\n",
      "epoch =  43  step =  150  of total steps  201  loss =  0.0932120680809021\n",
      "epoch =  43  step =  200  of total steps  201  loss =  0.1606590896844864\n",
      "epoch :  44  /  100  | TL :  0.21670372085414122  | VL :  0.2213744418695569  | VA :  93.0\n",
      "epoch =  44  step =  50  of total steps  201  loss =  0.3351249098777771\n",
      "epoch =  44  step =  100  of total steps  201  loss =  0.3722446858882904\n",
      "epoch =  44  step =  150  of total steps  201  loss =  0.13034334778785706\n",
      "epoch =  44  step =  200  of total steps  201  loss =  0.15353786945343018\n",
      "epoch :  45  /  100  | TL :  0.2181326214650377  | VL :  0.22048906376585364  | VA :  92.80000000000001\n",
      "epoch =  45  step =  50  of total steps  201  loss =  0.3352915048599243\n",
      "epoch =  45  step =  100  of total steps  201  loss =  0.3088238835334778\n",
      "epoch =  45  step =  150  of total steps  201  loss =  0.1644262671470642\n",
      "epoch =  45  step =  200  of total steps  201  loss =  0.10344810038805008\n",
      "epoch :  46  /  100  | TL :  0.21565163083633973  | VL :  0.21943769324570894  | VA :  92.80000000000001\n",
      "epoch =  46  step =  50  of total steps  201  loss =  0.10426834970712662\n",
      "epoch =  46  step =  100  of total steps  201  loss =  0.22308982908725739\n",
      "epoch =  46  step =  150  of total steps  201  loss =  0.26826727390289307\n",
      "epoch =  46  step =  200  of total steps  201  loss =  0.12887458503246307\n",
      "epoch :  47  /  100  | TL :  0.2080977896960517  | VL :  0.2306790049187839  | VA :  92.80000000000001\n",
      "epoch =  47  step =  50  of total steps  201  loss =  0.3430963158607483\n",
      "epoch =  47  step =  100  of total steps  201  loss =  0.29293397068977356\n",
      "epoch =  47  step =  150  of total steps  201  loss =  0.1422756165266037\n",
      "epoch =  47  step =  200  of total steps  201  loss =  0.20967553555965424\n",
      "epoch :  48  /  100  | TL :  0.21608938198925845  | VL :  0.20691866660490632  | VA :  93.4\n",
      "epoch =  48  step =  50  of total steps  201  loss =  0.2575813829898834\n",
      "epoch =  48  step =  100  of total steps  201  loss =  0.18993297219276428\n",
      "epoch =  48  step =  150  of total steps  201  loss =  0.2413366585969925\n",
      "epoch =  48  step =  200  of total steps  201  loss =  0.2662010192871094\n",
      "epoch :  49  /  100  | TL :  0.22168623643060822  | VL :  0.2203666465356946  | VA :  92.60000000000001\n",
      "epoch =  49  step =  50  of total steps  201  loss =  0.07282127439975739\n",
      "epoch =  49  step =  100  of total steps  201  loss =  0.3335210978984833\n",
      "epoch =  49  step =  150  of total steps  201  loss =  0.14121118187904358\n",
      "epoch =  49  step =  200  of total steps  201  loss =  0.20040209591388702\n",
      "epoch :  50  /  100  | TL :  0.21073319017887115  | VL :  0.20464605884626508  | VA :  93.2\n",
      "epoch =  50  step =  50  of total steps  201  loss =  0.1587088406085968\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# The Comet API key is read from the COMET_API_KEY environment variable instead of being\n",
    "# hardcoded here. NOTE(review): the key previously committed in this cell should be revoked.\n",
    "experiment = Experiment(api_key=os.environ.get(\"COMET_API_KEY\"),\n",
    "                        project_name=\"kd0\", workspace=\"akshaykvnit\")\n",
    "experiment.log_parameters(hyper_params)\n",
    "\n",
    "# Only the unfrozen (head) parameters are handed to the optimizer.\n",
    "trainable_params = [p for p in net.parameters() if p.requires_grad]\n",
    "optimizer = torch.optim.Adam(trainable_params, lr = hyper_params[\"learning_rate\"])\n",
    "total_step = len(data.train_ds) // hyper_params[\"batch_size\"]\n",
    "device = next(net.parameters()).device\n",
    "\n",
    "# BatchNorm layers whose affine parameters are all frozen. They are put back into eval\n",
    "# mode after every net.train(); otherwise their running mean/var keep being updated from\n",
    "# training batches, silently changing the feature extractor this cell is meant to keep fixed.\n",
    "frozen_norms = [m for m in net.modules()\n",
    "                if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))\n",
    "                and not any(p.requires_grad for p in m.parameters())]\n",
    "\n",
    "train_loss_list = list()\n",
    "val_loss_list = list()\n",
    "best_val_acc = 0  # best validation accuracy seen so far, in percent\n",
    "for epoch in range(hyper_params[\"num_epochs\"]):\n",
    "    trn = []\n",
    "    net.train()\n",
    "    for m in frozen_norms:\n",
    "        m.eval()\n",
    "    for i, (images, labels) in enumerate(data.train_dl) :\n",
    "        images = images.to(device).float()\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        y_pred = net(images)\n",
    "\n",
    "        loss = F.cross_entropy(y_pred, labels)\n",
    "        trn.append(loss.item())\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % 50 == 49 :\n",
    "            print('epoch = ', epoch, ' step = ', i + 1, ' of total steps ', total_step, ' loss = ', loss.item())\n",
    "\n",
    "    train_loss = (sum(trn) / len(trn))\n",
    "    train_loss_list.append(train_loss)\n",
    "\n",
    "    net.eval()\n",
    "    val = []\n",
    "    with torch.no_grad() :\n",
    "        for images, labels in data.valid_dl :\n",
    "            images = images.to(device).float()\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            # Forward pass\n",
    "            y_pred = net(images)\n",
    "            \n",
    "            loss = F.cross_entropy(y_pred, labels)\n",
    "            val.append(loss.item())\n",
    "\n",
    "    val_loss = sum(val) / len(val)\n",
    "    val_loss_list.append(val_loss)\n",
    "    val_acc = _get_accuracy(data.valid_dl, net)\n",
    "\n",
    "    print('epoch : ', epoch + 1, ' / ', hyper_params[\"num_epochs\"], ' | TL : ', train_loss, ' | VL : ', val_loss, ' | VA : ', val_acc * 100)\n",
    "    experiment.log_metric(\"train_loss\", train_loss)\n",
    "    experiment.log_metric(\"val_loss\", val_loss)\n",
    "    experiment.log_metric(\"val_acc\", val_acc)\n",
    "\n",
    "    # checkpoint whenever validation accuracy beats the best seen so far\n",
    "    if (val_acc * 100) > best_val_acc :\n",
    "        print('saving model')\n",
    "        best_val_acc = val_acc * 100\n",
    "        torch.save(net.state_dict(), '../saved_models/classifier/model0.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.942\n",
      "0.936\n"
     ]
    }
   ],
   "source": [
    "# Reload the best head checkpoint written by the training loop and report validation accuracy.\n",
    "net.cpu()\n",
    "net.load_state_dict(torch.load('../saved_models/classifier/model0.pt', map_location = 'cpu'))\n",
    "# Move to the GPU only when one exists, like the other cells (a bare net.cuda() fails on CPU-only machines).\n",
    "if torch.cuda.is_available() :\n",
    "    net = net.cuda()\n",
    "\n",
    "print(_get_accuracy(data.valid_dl, net))\n",
    "# NOTE(review): `learn` is not defined in this part of the notebook; presumably a fastai Learner\n",
    "# created in an earlier cell as a baseline - confirm it exists on a fresh kernel.\n",
    "print(_get_accuracy(data.valid_dl, learn.model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (ak_fastai)",
   "language": "python",
   "name": "ak_fastai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
